diff --git a/.Rbuildignore b/.Rbuildignore index 7fba94d70..09d152920 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -22,3 +22,4 @@ derby.log ^vignettes/articles$ ^[\.]?air\.toml$ ^\.vscode$ +^[.]?air[.]toml$ diff --git a/.vscode/settings.json b/.vscode/settings.json index f2d0b79d6..a9f69fe41 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,9 @@ "[r]": { "editor.formatOnSave": true, "editor.defaultFormatter": "Posit.air-vscode" + }, + "[quarto]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "quarto.quarto" } } diff --git a/R/aaa-import-standalone-obj-type.R b/R/aaa-import-standalone-obj-type.R index 8e3c07df4..35e88ca35 100644 --- a/R/aaa-import-standalone-obj-type.R +++ b/R/aaa-import-standalone-obj-type.R @@ -89,12 +89,11 @@ obj_type_friendly <- function(x, value = TRUE) { typeof(x), logical = "`NA`", integer = "an integer `NA`", - double = - if (is.nan(x)) { - "`NaN`" - } else { - "a numeric `NA`" - }, + double = if (is.nan(x)) { + "`NaN`" + } else { + "a numeric `NA`" + }, complex = "a complex `NA`", character = "a character `NA`", .rlang_stop_unexpected_typeof(x) @@ -296,14 +295,16 @@ obj_type_oo <- function(x) { #' @param ... Arguments passed to [abort()]. #' @inheritParams args_error_context #' @noRd -stop_input_type <- function(x, - what, - ..., - allow_na = FALSE, - allow_null = FALSE, - show_value = TRUE, - arg = caller_arg(x), - call = caller_env()) { +stop_input_type <- function( + x, + what, + ..., + allow_na = FALSE, + allow_null = FALSE, + show_value = TRUE, + arg = caller_arg(x), + call = caller_env() +) { # From standalone-cli.R cli <- env_get_list( nms = c("format_arg", "format_code"), diff --git a/R/aaa-import-standalone-types-check.R b/R/aaa-import-standalone-types-check.R index 22ea57ba8..59a25d815 100644 --- a/R/aaa-import-standalone-types-check.R +++ b/R/aaa-import-standalone-types-check.R @@ -60,13 +60,23 @@ .standalone_types_check_dot_call <- .Call -check_bool <- function(x, - ..., - allow_na = FALSE, - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { - if (!missing(x) && .standalone_types_check_dot_call(ffi_standalone_is_bool_1.0.7, x, allow_na, allow_null)) { +check_bool <- function( + x, + ..., + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { + if ( + !missing(x) && + .standalone_types_check_dot_call( + ffi_standalone_is_bool_1.0.7, + x, + allow_na, + allow_null + ) + ) { return(invisible(NULL)) } @@ -81,13 +91,15 @@ check_bool <- function(x, ) } -check_string <- function(x, - ..., - allow_empty = TRUE, - allow_na = FALSE, - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_string <- function( + x, + ..., + allow_empty = TRUE, + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { is_string <- .rlang_check_is_string( x, @@ -111,10 +123,7 @@ check_string <- function(x, ) } -.rlang_check_is_string <- function(x, - allow_empty, - allow_na, - allow_null) { +.rlang_check_is_string <- function(x, allow_empty, allow_na, allow_null) { if (is_string(x)) { if (allow_empty || !is_string(x, "")) { return(TRUE) @@ -132,11 +141,13 @@ check_string <- function(x, FALSE } -check_name <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_name <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { is_string <- .rlang_check_is_string( x, @@ -164,27 +175,32 @@ IS_NUMBER_true <- 0 
IS_NUMBER_false <- 1 IS_NUMBER_oob <- 2 -check_number_decimal <- function(x, - ..., - min = NULL, - max = NULL, - allow_infinite = TRUE, - allow_na = FALSE, - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_number_decimal <- function( + x, + ..., + min = NULL, + max = NULL, + allow_infinite = TRUE, + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (missing(x)) { exit_code <- IS_NUMBER_false - } else if (0 == (exit_code <- .standalone_types_check_dot_call( - ffi_standalone_check_number_1.0.7, - x, - allow_decimal = TRUE, - min, - max, - allow_infinite, - allow_na, - allow_null - ))) { + } else if ( + 0 == + (exit_code <- .standalone_types_check_dot_call( + ffi_standalone_check_number_1.0.7, + x, + allow_decimal = TRUE, + min, + max, + allow_infinite, + allow_na, + allow_null + )) + ) { return(invisible(NULL)) } @@ -202,27 +218,32 @@ check_number_decimal <- function(x, ) } -check_number_whole <- function(x, - ..., - min = NULL, - max = NULL, - allow_infinite = FALSE, - allow_na = FALSE, - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_number_whole <- function( + x, + ..., + min = NULL, + max = NULL, + allow_infinite = FALSE, + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (missing(x)) { exit_code <- IS_NUMBER_false - } else if (0 == (exit_code <- .standalone_types_check_dot_call( - ffi_standalone_check_number_1.0.7, - x, - allow_decimal = FALSE, - min, - max, - allow_infinite, - allow_na, - allow_null - ))) { + } else if ( + 0 == + (exit_code <- .standalone_types_check_dot_call( + ffi_standalone_check_number_1.0.7, + x, + allow_decimal = FALSE, + min, + max, + allow_infinite, + allow_na, + allow_null + )) + ) { return(invisible(NULL)) } @@ -240,16 +261,18 @@ check_number_whole <- function(x, ) } -.stop_not_number <- function(x, - ..., - exit_code, - allow_decimal, - min, - max, - allow_na, - allow_null, - arg, - call) { +.stop_not_number <- function( + x, + ..., + exit_code, + allow_decimal, + min, + max, + allow_na, + allow_null, + arg, + call +) { if (allow_decimal) { what <- "a number" } else { @@ -282,11 +305,13 @@ check_number_whole <- function(x, ) } -check_symbol <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_symbol <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_symbol(x)) { return(invisible(NULL)) @@ -307,11 +332,13 @@ check_symbol <- function(x, ) } -check_arg <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_arg <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_symbol(x)) { return(invisible(NULL)) @@ -332,11 +359,13 @@ check_arg <- function(x, ) } -check_call <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_call <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_call(x)) { return(invisible(NULL)) @@ -357,11 +386,13 @@ check_call <- function(x, ) } -check_environment <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_environment <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_environment(x)) { return(invisible(NULL)) @@ -382,11 +413,13 @@ 
check_environment <- function(x, ) } -check_function <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_function <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_function(x)) { return(invisible(NULL)) @@ -407,11 +440,13 @@ check_function <- function(x, ) } -check_closure <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_closure <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_closure(x)) { return(invisible(NULL)) @@ -432,11 +467,13 @@ check_closure <- function(x, ) } -check_formula <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_formula <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_formula(x)) { return(invisible(NULL)) @@ -462,13 +499,14 @@ check_formula <- function(x, # TODO: Figure out what to do with logical `NA` and `allow_na = TRUE` -check_character <- function(x, - ..., - allow_na = TRUE, - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { - +check_character <- function( + x, + ..., + allow_na = TRUE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_character(x)) { if (!allow_na && any(is.na(x))) { @@ -497,11 +535,13 @@ check_character <- function(x, ) } -check_logical <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_logical <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is_logical(x)) { return(invisible(NULL)) @@ -522,11 +562,13 @@ check_logical <- function(x, ) } -check_data_frame <- function(x, - ..., - allow_null = FALSE, - arg = caller_arg(x), - call = caller_env()) { +check_data_frame <- function( + x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { if (!missing(x)) { if (is.data.frame(x)) { return(invisible(NULL)) diff --git a/R/aaa.R b/R/aaa.R index b58842f20..2852c9313 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -1,13 +1,11 @@ - maybe_multivariate <- function(results, object) { - if (isTRUE(ncol(results) > 1)) { nms <- colnames(results) results <- as_tibble(results, .name_repair = "minimal") if (length(nms) == 0 && length(object$preproc$y_var) == ncol(results)) { names(results) <- object$preproc$y_var } - } else { + } else { results <- unname(results[, 1]) } results @@ -33,7 +31,7 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { # used by logistic_reg() and gen_additive_mod() logistic_lp_to_conf_int <- function(results, object) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2 const <- stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) trans <- object$fit$family$linkinv @@ -51,35 +49,36 @@ logistic_lp_to_conf_int <- function(results, object) { colnames(res_2) <- c(lo_nms[2], hi_nms[2]) res <- bind_cols(res_1, res_2) - if (object$spec$method$pred$conf_int$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) { res$.std_error <- results$se.fit + } res } # used by linear_reg() and gen_additive_mod() linear_lp_to_conf_int <- -function(results, object) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 - const <- - stats::qt(hf_lvl, df = 
object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res <- - tibble( - .pred_lower = trans(results$fit - const * results$se.fit), - .pred_upper = trans(results$fit + const * results$se.fit) - ) - # In case of inverse or other links - if (any(res$.pred_upper < res$.pred_lower)) { - nms <- names(res) - res <- res[, 2:1] - names(res) <- nms - } + function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2 + const <- + stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res <- + tibble( + .pred_lower = trans(results$fit - const * results$se.fit), + .pred_upper = trans(results$fit + const * results$se.fit) + ) + # In case of inverse or other links + if (any(res$.pred_upper < res$.pred_lower)) { + nms <- names(res) + res <- res[, 2:1] + names(res) <- nms + } - if (object$spec$method$pred$conf_int$extras$std_error) { - res$.std_error <- results$se.fit + if (object$spec$method$pred$conf_int$extras$std_error) { + res$.std_error <- results$se.fit + } + res } - res -} combine_words <- function(x) { if (isTRUE(length(x) > 2)) { diff --git a/R/aaa_models.R b/R/aaa_models.R index ee18048bb..0789b838d 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -1,6 +1,11 @@ # Initialize model environments -all_modes <- c("classification", "regression", "censored regression", "quantile regression") +all_modes <- c( + "classification", + "regression", + "censored regression", + "quantile regression" +) # ------------------------------------------------------------------------------ @@ -32,8 +37,19 @@ parsnip$modes <- c(all_modes, "unknown") # ------------------------------------------------------------------------------ pred_types <- - c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile", - "time", "survival", "linear_pred", "hazard") + c( + "raw", + "numeric", + "class", + "prob", + "conf_int", + "pred_int", + "quantile", + "time", + "survival", + "linear_pred", + "hazard" + ) # ------------------------------------------------------------------------------ @@ -96,8 +112,10 @@ error_set_object <- function(object, func) { "`{func}()` expected a model specification to be supplied to the \\ `object` argument, but received a(n) `{class(object)[1]}` object." - if (inherits(object, "function") && - isTRUE(environment(object)$.packageName == "parsnip")) { + if ( + inherits(object, "function") && + isTRUE(environment(object)$.packageName == "parsnip") + ) { msg <- c( msg, "i" = "Did you mistakenly pass `model_function` rather than `model_function()`?" @@ -176,10 +194,15 @@ stop_missing_engine <- function(cls, call) { info <- get_from_env(cls) |> dplyr::group_by(mode) |> - dplyr::summarize(msg = paste0(unique(mode), " {", - paste0(unique(engine), collapse = ", "), - "}"), - .groups = "drop") + dplyr::summarize( + msg = paste0( + unique(mode), + " {", + paste0(unique(engine), collapse = ", "), + "}" + ), + .groups = "drop" + ) if (nrow(info) == 0) { cli::cli_abort("No known engines for {.fn {cls}}.", call = call) } @@ -203,7 +226,6 @@ check_mode_for_new_engine <- function(cls, eng, mode, call = caller_env()) { # check if class and mode and engine are compatible check_spec_mode_engine_val <- function(cls, eng, mode, call = caller_env()) { - all_modes <- get_from_env(paste0(cls, "_modes")) if (!(mode %in% all_modes)) { cli::cli_abort( @@ -218,7 +240,10 @@ check_spec_mode_engine_val <- function(cls, eng, mode, call = caller_env()) { # parsnip model environment. If so, return early. 
# If not, troubleshoot more precisely and raise a relevant error. model_env_match <- - vctrs::vec_slice(model_info, model_info$engine == eng & model_info$mode == mode) + vctrs::vec_slice( + model_info, + model_info$engine == eng & model_info$mode == mode + ) if (vctrs::vec_size(model_env_match) == 1) { return(invisible(NULL)) @@ -302,7 +327,7 @@ check_func_val <- function(func, call = caller_env()) { nms <- sort(names(func)) - if (all(is.null(nms))) { + if (all(is.null(nms))) { cli::cli_abort(msg, call = call) } @@ -379,7 +404,10 @@ check_fit_info <- function(fit_obj, call = caller_env()) { check_func_val(fit_obj$func) if (!is.list(fit_obj$defaults)) { - cli::cli_abort("The {.field defaults} element should be a list.", call = call) + cli::cli_abort( + "The {.field defaults} element should be a list.", + call = call + ) } invisible(NULL) @@ -560,10 +588,17 @@ set_new_model <- function(model) { current <- get_model_env() set_env_val("models", unique(c(current$models, model))) - set_env_val(model, tibble::new_tibble(list(engine = character(0), mode = character(0)))) + set_env_val( + model, + tibble::new_tibble(list(engine = character(0), mode = character(0))) + ) set_env_val( paste0(model, "_pkgs"), - tibble::new_tibble(list(engine = character(0), pkg = list(), mode = character(0))) + tibble::new_tibble(list( + engine = character(0), + pkg = list(), + mode = character(0) + )) ) set_env_val(paste0(model, "_modes"), "unknown") set_env_val( @@ -659,13 +694,16 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) { old_args <- get_from_env(paste0(model, "_args")) new_arg <- - tibble::new_tibble(list( - engine = eng, - parsnip = parsnip, - original = original, - func = list(func), - has_submodel = has_submodel - ), nrow = 1) + tibble::new_tibble( + list( + engine = eng, + parsnip = parsnip, + original = original, + func = list(func), + has_submodel = has_submodel + ), + nrow = 1 + ) updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE) if (inherits(updated, "try-error")) { @@ -774,9 +812,15 @@ get_dependency <- function(model) { # This will be used to see if the same information is being registered for the # same model/mode/engine (and prediction type). If it already exists and the # new information is different, fail with a message. 
See issue #653 -is_discordant_info <- function(model, mode, eng, candidate, - pred_type = NULL, component = "fit", - call = caller_env()) { +is_discordant_info <- function( + model, + mode, + eng, + candidate, + pred_type = NULL, + component = "fit", + call = caller_env() +) { current <- get_from_env(paste0(model, "_", component)) # For older versions of parsnip before set_encoding() @@ -785,11 +829,10 @@ is_discordant_info <- function(model, mode, eng, candidate, if (new_encoding) { return(TRUE) } else { - current <- dplyr::filter(current, engine == eng & mode == !!mode) + current <- dplyr::filter(current, engine == eng & mode == !!mode) } if (component == "predict" & !is.null(pred_type)) { - current <- dplyr::filter(current, type == pred_type) p_type <- "and prediction type {.val {pred_type}} " } else { @@ -836,7 +879,6 @@ check_unregistered <- function(model, mode, eng, call = caller_env()) { } - #' @rdname set_new_model #' @keywords internal #' @export @@ -848,11 +890,14 @@ set_fit <- function(model, mode, eng, value) { check_unregistered(model, mode, eng) new_fit <- - tibble::new_tibble(list( - engine = eng, - mode = mode, - value = list(value) - ), nrow = 1) + tibble::new_tibble( + list( + engine = eng, + mode = mode, + value = list(value) + ), + nrow = 1 + ) if (!is_discordant_info(model, mode, eng, new_fit)) { return(invisible(NULL)) @@ -882,7 +927,9 @@ get_fit <- function(model) { check_model_exists(model) fit_name <- paste0(model, "_fit") if (!any(fit_name != rlang::env_names(get_model_env()))) { - cli::cli_abort("{.arg {model}} does not have a {.fn fit} method in parsnip.") + cli::cli_abort( + "{.arg {model}} does not have a {.fn fit} method in parsnip." + ) } rlang::env_get(get_model_env(), fit_name) } @@ -907,7 +954,14 @@ set_pred <- function(model, mode, eng, type, value) { nrow = 1 ) - pred_check <- is_discordant_info(model, mode, eng, new_pred, pred_type = type, component = "predict") + pred_check <- is_discordant_info( + model, + mode, + eng, + new_pred, + pred_type = type, + component = "predict" + ) if (!pred_check) { return(invisible(NULL)) } @@ -939,7 +993,9 @@ get_pred_type <- function(model, type) { } all_preds <- rlang::env_get(get_model_env(), pred_name) if (!any(all_preds$type == type)) { - cli::cli_abort("{.arg {model}} does not have any prediction methods in parsnip.") + cli::cli_abort( + "{.arg {model}} does not have any prediction methods in parsnip." + ) } dplyr::filter(all_preds, type == !!type) } @@ -978,8 +1034,8 @@ show_model_info <- function(model) { ) |> dplyr::select(engine, mode, has_wts) - engine_weight_info <- engines |> - dplyr::left_join(weight_info, by = c("engine", "mode")) |> + engine_weight_info <- engines |> + dplyr::left_join(weight_info, by = c("engine", "mode")) |> dplyr::mutate( engine = paste0(engine, has_wts), mode = format(paste0(mode, ": ")) @@ -1018,7 +1074,11 @@ show_model_info <- function(model) { dplyr::group_by(engine) |> dplyr::mutate( engine2 = ifelse(dplyr::row_number() == 1, engine, ""), - parsnip = ifelse(dplyr::row_number() == 1, paste0("\n", parsnip), parsnip), + parsnip = ifelse( + dplyr::row_number() == 1, + paste0("\n", parsnip), + parsnip + ), lab = paste0(engine2, parsnip) ) |> dplyr::ungroup() |> @@ -1064,7 +1124,7 @@ show_model_info <- function(model) { #' @rdname set_new_model #' @keywords internal #' @export -pred_value_template <- function(pre = NULL, post = NULL, func, ...) { +pred_value_template <- function(pre = NULL, post = NULL, func, ...) 
{ if (rlang::is_missing(func)) { cli::cli_abort( "Please supply a value to {.arg func}. See {.help [{.fun set_pred}](parsnip::set_pred)}." @@ -1079,10 +1139,12 @@ check_encodings <- function(x, call = caller_env()) { if (!is.list(x)) { cli::cli_abort("{.arg values} should be a list.", call = call) } - req_args <- list(predictor_indicators = rlang::na_chr, - compute_intercept = rlang::na_lgl, - remove_intercept = rlang::na_lgl, - allow_sparse_x = rlang::na_lgl) + req_args <- list( + predictor_indicators = rlang::na_chr, + compute_intercept = rlang::na_lgl, + remove_intercept = rlang::na_lgl, + allow_sparse_x = rlang::na_lgl + ) missing_args <- setdiff(names(req_args), names(x)) if (length(missing_args) > 0) { @@ -1112,11 +1174,20 @@ set_encoding <- function(model, mode, eng, options) { check_mode_val(mode) check_encodings(options) - keys <- tibble::new_tibble(list(model = model, engine = eng, mode = mode), nrow = 1) + keys <- tibble::new_tibble( + list(model = model, engine = eng, mode = mode), + nrow = 1 + ) options <- tibble::as_tibble(options) new_values <- dplyr::bind_cols(keys, options) - enc_check <- is_discordant_info(model, mode, eng, new_values, component = "encoding") + enc_check <- is_discordant_info( + model, + mode, + eng, + new_values, + component = "encoding" + ) if (!enc_check) { return(invisible(NULL)) } @@ -1150,8 +1221,14 @@ get_encoding <- function(model) { remove_intercept = TRUE, allow_sparse_x = FALSE ) |> - dplyr::select(model, engine, mode, predictor_indicators, - compute_intercept, remove_intercept) + dplyr::select( + model, + engine, + mode, + predictor_indicators, + compute_intercept, + remove_intercept + ) } res } diff --git a/R/aaa_multi_predict.R b/R/aaa_multi_predict.R index 8f9b67be9..683990152 100644 --- a/R/aaa_multi_predict.R +++ b/R/aaa_multi_predict.R @@ -121,8 +121,8 @@ multi_predict_args.default <- function(object, ...) { multi_predict_args.model_fit <- function(object, ...) { model_type <- class(object$spec)[1] arg_info <- get_from_env(paste0(model_type, "_args")) - arg_info <- arg_info[arg_info$engine == object$spec$engine,] - arg_info <- arg_info[arg_info$has_submodel,] + arg_info <- arg_info[arg_info$engine == object$spec$engine, ] + arg_info <- arg_info[arg_info$has_submodel, ] if (nrow(arg_info) == 0) { res <- NA_character_ diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 45a1942ed..254d8d6b3 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -13,5 +13,7 @@ matrix_to_quantile_pred <- function(x, object) { n_pred_quantiles <- ncol(x) quantile_levels <- object$spec$quantile_levels - tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels))) + tibble::new_tibble( + x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels)) + ) } diff --git a/R/adds.R b/R/adds.R index 53d43e7df..c991907b3 100644 --- a/R/adds.R +++ b/R/adds.R @@ -12,4 +12,3 @@ add_rowindex <- function(x) { x <- dplyr::mutate(x, .row = seq_len(nrow(x))) x } - diff --git a/R/arguments.R b/R/arguments.R index b35a5338f..df67ac8dd 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -61,8 +61,9 @@ set_args <- function(object, ...) { #' @export set_args.model_spec <- function(object, ...) { the_dots <- enquos(...) - if (length(the_dots) == 0) + if (length(the_dots) == 0) { cli::cli_abort("Please pass at least one named argument.") + } main_args <- names(object$args) new_args <- names(the_dots) for (i in new_args) { @@ -85,7 +86,7 @@ set_args.model_spec <- function(object, ...) { } #' @export -set_args.default <- function(object,...) 
{ +set_args.default <- function(object, ...) { error_set_object(object, func = "set_args") invisible(FALSE) @@ -114,19 +115,22 @@ set_mode.model_spec <- function(object, mode, quantile_levels = NULL, ...) { # determine if the model specification could feasibly match any entry # in the union of the parsnip model environment and model_info_table. # if not, trigger an error based on the (possibly inferred) model spec slots. - if (!spec_is_possible(spec = object, - mode = mode, user_specified_mode = TRUE)) { + if ( + !spec_is_possible(spec = object, mode = mode, user_specified_mode = TRUE) + ) { check_spec_mode_engine_val(cls, object$engine, mode) } object$mode <- mode object$user_specified_mode <- TRUE if (mode == "quantile regression") { - hardhat::check_quantile_levels(quantile_levels) + hardhat::check_quantile_levels(quantile_levels) } else { if (!is.null(quantile_levels)) { - cli::cli_warn("{.arg quantile_levels} is only used when the mode is - {.val quantile regression}.") + cli::cli_warn( + "{.arg quantile_levels} is only used when the mode is + {.val quantile regression}." + ) } } @@ -146,8 +150,9 @@ set_mode.default <- function(object, mode, ...) { maybe_eval <- function(x) { # if descriptors are in `x`, eval fails y <- try(rlang::eval_tidy(x), silent = TRUE) - if (inherits(y, "try-error")) + if (inherits(y, "try-error")) { y <- x + } y } @@ -157,7 +162,7 @@ maybe_eval <- function(x) { #' @param spec A [model specification][model_spec]. #' @param ... Not used. eval_args <- function(spec, ...) { - spec$args <- purrr::map(spec$args, maybe_eval) + spec$args <- purrr::map(spec$args, maybe_eval) spec$eng_args <- purrr::map(spec$eng_args, maybe_eval) spec } @@ -192,14 +197,20 @@ eval_args <- function(spec, ...) { make_call <- function(fun, ns, args, ...) { # remove any null or placeholders (`missing_args`) that remain discard <- - vapply(args, function(x) - is_missing_arg(x) | is.null(x), logical(1)) + vapply( + args, + function(x) { + is_missing_arg(x) | is.null(x) + }, + logical(1) + ) args <- args[!discard] if (!is.null(ns) & !is.na(ns)) { out <- call2(fun, !!!args, .ns = ns) - } else + } else { out <- call2(fun, !!!args) + } out } @@ -240,11 +251,11 @@ make_form_call <- function(object, env = NULL) { # add data arguments for (i in seq_along(data_args)) { - fit_args[[ unname(data_args[i]) ]] <- sym(names(data_args)[i]) + fit_args[[unname(data_args[i])]] <- sym(names(data_args)[i]) } # sub in actual formula - fit_args[[ unname(data_args["formula"]) ]] <- env$formula + fit_args[[unname(data_args["formula"])]] <- env$formula # TODO remove weights col from data? 
if (object$engine == "spark") { @@ -277,8 +288,8 @@ make_xy_call <- function(object, target, env, call = rlang::caller_env()) { data_args <- object$method$fit$data } - object$method$fit$args[[ unname(data_args["y"]) ]] <- rlang::expr(y) - object$method$fit$args[[ unname(data_args["x"]) ]] <- + object$method$fit$args[[unname(data_args["y"])]] <- rlang::expr(y) + object$method$fit$args[[unname(data_args["x"])]] <- switch( target, none = rlang::expr(x), @@ -288,7 +299,9 @@ make_xy_call <- function(object, target, env, call = rlang::caller_env()) { cli::cli_abort("Invalid data type target: {target}.", call = call) ) if (uses_weights) { - object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights) + object$method$fit$args[[unname(data_args["weights"])]] <- rlang::expr( + weights + ) } fit_call <- make_call( diff --git a/R/auto_ml.R b/R/auto_ml.R index 4b40c4a2b..57818717b 100644 --- a/R/auto_ml.R +++ b/R/auto_ml.R @@ -26,8 +26,13 @@ #' @export auto_ml <- function(mode = "unknown", engine = "h2o") { args <- list() - out <- list(args = args, eng_args = NULL, - mode = mode, method = NULL, engine = engine) + out <- list( + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = engine + ) class(out) <- make_classes("auto_ml") out } diff --git a/R/autoplot.R b/R/autoplot.R index 685d1f067..e53fac924 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -32,8 +32,13 @@ autoplot.model_fit <- function(object, ...) { #' @export #' @rdname autoplot.model_fit -autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL, - top_n = 3L) { +autoplot.glmnet <- function( + object, + ..., + min_penalty = 0, + best_penalty = NULL, + top_n = 3L +) { check_number_decimal(min_penalty, min = 0, max = 1) check_number_decimal(best_penalty, min = 0, max = 1, allow_null = TRUE) check_number_whole(top_n, min = 1, max = Inf, allow_infinite = TRUE) @@ -68,13 +73,18 @@ reformat_coefs <- function(x, p, penalty) { num_estimates <- nrow(x) if (num_estimates > p) { # The intercept is first - x <- x[-(num_estimates - p),, drop = FALSE] + x <- x[-(num_estimates - p), , drop = FALSE] } term_lab <- rownames(x) colnames(x) <- paste(seq_along(penalty)) x <- tibble::as_tibble(x) x$term <- term_lab - x <- tidyr::pivot_longer(x, cols = -term, names_to = "index", values_to = "estimate") + x <- tidyr::pivot_longer( + x, + cols = -term, + names_to = "index", + values_to = "estimate" + ) x$penalty <- rep(penalty, p) x$index <- NULL x @@ -90,8 +100,14 @@ top_coefs <- function(x, top_n = 5) { dplyr::slice(seq_len(top_n)) } -autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, - call = rlang::caller_env(), ...) { +autoplot_glmnet <- function( + x, + min_penalty = 0, + best_penalty = NULL, + top_n = 3L, + call = rlang::caller_env(), + ... 
+) { tidy_coefs <- map_glmnet_coefs(x, call = call) |> dplyr::filter(penalty >= min_penalty) @@ -134,10 +150,15 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, # plot the paths and highlight the large values p <- tidy_coefs |> - ggplot2::ggplot(ggplot2::aes(x = penalty, y = estimate, group = term, col = term)) + ggplot2::ggplot(ggplot2::aes( + x = penalty, + y = estimate, + group = term, + col = term + )) if (has_groups) { - p <- p + ggplot2::facet_wrap(~ class) + p <- p + ggplot2::facet_wrap(~class) } if (!is.null(best_penalty)) { @@ -148,7 +169,7 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ggplot2::geom_line(alpha = .4, show.legend = FALSE) + ggplot2::scale_x_log10() - if(top_n > 0) { + if (top_n > 0) { rlang::check_installed("ggrepel") p <- p + ggrepel::geom_label_repel( diff --git a/R/bag_mars.R b/R/bag_mars.R index b22722868..03fa34b9a 100644 --- a/R/bag_mars.R +++ b/R/bag_mars.R @@ -22,15 +22,17 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("bag_mars")} #' @export bag_mars <- - function(mode = "unknown", - num_terms = NULL, - prod_degree = NULL, - prune_method = NULL, - engine = "earth") { + function( + mode = "unknown", + num_terms = NULL, + prod_degree = NULL, + prune_method = NULL, + engine = "earth" + ) { args <- list( - num_terms = enquo(num_terms), - prod_degree = enquo(prod_degree), - prune_method = enquo(prune_method) + num_terms = enquo(num_terms), + prod_degree = enquo(prod_degree), + prune_method = enquo(prune_method) ) new_model_spec( @@ -52,15 +54,19 @@ bag_mars <- #' @inheritParams mars #' @export update.bag_mars <- - function(object, - parameters = NULL, - num_terms = NULL, prod_degree = NULL, prune_method = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + num_terms = NULL, + prod_degree = NULL, + prune_method = NULL, + fresh = FALSE, + ... + ) { args <- list( - num_terms = enquo(num_terms), - prod_degree = enquo(prod_degree), - prune_method = enquo(prune_method) + num_terms = enquo(num_terms), + prod_degree = enquo(prod_degree), + prune_method = enquo(prune_method) ) update_spec( diff --git a/R/bag_mlp.R b/R/bag_mlp.R index aa7ea4b30..d88f8c06e 100644 --- a/R/bag_mlp.R +++ b/R/bag_mlp.R @@ -20,15 +20,17 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("bag_mlp")} #' @export bag_mlp <- - function(mode = "unknown", - hidden_units = NULL, - penalty = NULL, - epochs = NULL, - engine = "nnet") { + function( + mode = "unknown", + hidden_units = NULL, + penalty = NULL, + epochs = NULL, + engine = "nnet" + ) { args <- list( - hidden_units = enquo(hidden_units), - penalty = enquo(penalty), - epochs = enquo(epochs) + hidden_units = enquo(hidden_units), + penalty = enquo(penalty), + epochs = enquo(epochs) ) new_model_spec( @@ -50,15 +52,19 @@ bag_mlp <- #' @inheritParams mars #' @export update.bag_mlp <- - function(object, - parameters = NULL, - hidden_units = NULL, penalty = NULL, epochs = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + hidden_units = NULL, + penalty = NULL, + epochs = NULL, + fresh = FALSE, + ... 
+ ) { args <- list( - hidden_units = enquo(hidden_units), - penalty = enquo(penalty), - epochs = enquo(epochs) + hidden_units = enquo(hidden_units), + penalty = enquo(penalty), + epochs = enquo(epochs) ) update_spec( diff --git a/R/bag_tree.R b/R/bag_tree.R index e80fc200a..6ab70c366 100644 --- a/R/bag_tree.R +++ b/R/bag_tree.R @@ -25,16 +25,18 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("bag_tree")} #' @export bag_tree <- - function(mode = "unknown", - cost_complexity = 0, - tree_depth = NULL, - min_n = 2, - class_cost = NULL, - engine = "rpart") { + function( + mode = "unknown", + cost_complexity = 0, + tree_depth = NULL, + min_n = 2, + class_cost = NULL, + engine = "rpart" + ) { args <- list( - cost_complexity = enquo(cost_complexity), - tree_depth = enquo(tree_depth), - min_n = enquo(min_n), + cost_complexity = enquo(cost_complexity), + tree_depth = enquo(tree_depth), + min_n = enquo(min_n), class_cost = enquo(class_cost) ) @@ -58,17 +60,21 @@ bag_tree <- #' @inheritParams bag_tree #' @export update.bag_tree <- - function(object, - parameters = NULL, - cost_complexity = NULL, tree_depth = NULL, min_n = NULL, - class_cost = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + cost_complexity = NULL, + tree_depth = NULL, + min_n = NULL, + class_cost = NULL, + fresh = FALSE, + ... + ) { args <- list( cost_complexity = enquo(cost_complexity), - tree_depth = enquo(tree_depth), - min_n = enquo(min_n), - class_cost = enquo(class_cost) + tree_depth = enquo(tree_depth), + min_n = enquo(min_n), + class_cost = enquo(class_cost) ) update_spec( diff --git a/R/bart.R b/R/bart.R index 7130a20ad..0e6b1fdb1 100644 --- a/R/bart.R +++ b/R/bart.R @@ -70,11 +70,14 @@ #' @export bart <- - function(mode = "unknown", engine = "dbarts", - trees = NULL, prior_terminal_node_coef = NULL, - prior_terminal_node_expo = NULL, - prior_outcome_range = NULL) { - + function( + mode = "unknown", + engine = "dbarts", + trees = NULL, + prior_terminal_node_coef = NULL, + prior_terminal_node_expo = NULL, + prior_outcome_range = NULL + ) { args <- list( trees = enquo(trees), prior_terminal_node_coef = enquo(prior_terminal_node_coef), @@ -105,14 +108,16 @@ bart <- #' a node is a terminal node. #' @export update.bart <- - function(object, - parameters = NULL, - trees = NULL, - prior_terminal_node_coef = NULL, - prior_terminal_node_expo = NULL, - prior_outcome_range = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + trees = NULL, + prior_terminal_node_coef = NULL, + prior_terminal_node_expo = NULL, + prior_outcome_range = NULL, + fresh = FALSE, + ... + ) { args <- list( trees = enquo(trees), prior_terminal_node_coef = enquo(prior_terminal_node_coef), @@ -138,11 +143,17 @@ update.bart <- #' @param std_err Attach column for standard error of prediction or not. 
#' @export #' @keywords internal -dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALSE) { +dbart_predict_calc <- function( + obj, + new_data, + type, + level = 0.95, + std_err = FALSE +) { types <- c("numeric", "class", "prob", "conf_int", "pred_int") mod_mode <- obj$spec$mode - lo <- (1 - level)/2 - hi <- 1 - lo + lo <- (1 - level) / 2 + hi <- 1 - lo if (type == "conf_int") { post_dist <- predict(obj$fit, new_data, type = "ev") @@ -175,10 +186,10 @@ dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALS res <- tibble::tibble( - .pred_lower_a = 1 - bnds[,2], - .pred_lower_b = bnds[,1], - .pred_upper_a = 1 - bnds[,1], - .pred_upper_b = bnds[,2] + .pred_lower_a = 1 - bnds[, 2], + .pred_lower_b = bnds[, 1], + .pred_upper_a = 1 - bnds[, 1], + .pred_upper_b = bnds[, 2] ) |> rlang::set_names( c( @@ -193,4 +204,3 @@ dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALS } res } - diff --git a/R/bart_data.R b/R/bart_data.R index 0d2cde864..f474e459d 100644 --- a/R/bart_data.R +++ b/R/bart_data.R @@ -1,4 +1,3 @@ - set_new_model("bart") set_model_mode("bart", "classification") @@ -101,12 +100,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = quote(object), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + obj = quote(object), + new_data = quote(new_data), + type = "numeric" + ) ) ) @@ -119,9 +117,7 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list(obj = quote(object), - new_data = quote(new_data)) + args = list(obj = quote(object), new_data = quote(new_data)) ) ) @@ -135,14 +131,13 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = expr(object), - new_data = expr(new_data), - type = "conf_int", - level = expr(level), - std_err = expr(std_error) - ) + args = list( + obj = expr(object), + new_data = expr(new_data), + type = "conf_int", + level = expr(level), + std_err = expr(std_error) + ) ) ) set_pred( @@ -154,14 +149,13 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = expr(object), - new_data = expr(new_data), - type = "pred_int", - level = expr(level), - std_err = expr(std_error) - ) + args = list( + obj = expr(object), + new_data = expr(new_data), + type = "pred_int", + level = expr(level), + std_err = expr(std_error) + ) ) ) @@ -175,12 +169,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = quote(object), - new_data = quote(new_data), - type = "class" - ) + args = list( + obj = quote(object), + new_data = quote(new_data), + type = "class" + ) ) ) @@ -193,12 +186,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = quote(object), - new_data = quote(new_data), - type = "prob" - ) + args = list( + obj = quote(object), + new_data = quote(new_data), + type = "prob" + ) ) ) @@ -212,14 +204,13 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = expr(object), - new_data = expr(new_data), - type = "conf_int", - level = expr(level), - std_err = expr(std_error) - ) + args = list( + obj = expr(object), + new_data = expr(new_data), + type = "conf_int", + level = expr(level), + std_err = expr(std_error) + ) ) ) set_pred( @@ -231,14 +222,13 @@ 
set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = expr(object), - new_data = expr(new_data), - type = "pred_int", - level = expr(level), - std_err = expr(std_error) - ) + args = list( + obj = expr(object), + new_data = expr(new_data), + type = "pred_int", + level = expr(level), + std_err = expr(std_error) + ) ) ) @@ -251,10 +241,9 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = quote(object), - new_data = quote(new_data) - ) + args = list( + obj = quote(object), + new_data = quote(new_data) + ) ) ) diff --git a/R/boost_tree.R b/R/boost_tree.R index 8e67bf999..dc0e288d0 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -53,13 +53,18 @@ #' boost_tree(mode = "classification", trees = 20) #' @export boost_tree <- - function(mode = "unknown", - engine = "xgboost", - mtry = NULL, trees = NULL, min_n = NULL, - tree_depth = NULL, learn_rate = NULL, - loss_reduction = NULL, - sample_size = NULL, - stop_iter = NULL) { + function( + mode = "unknown", + engine = "xgboost", + mtry = NULL, + trees = NULL, + min_n = NULL, + tree_depth = NULL, + learn_rate = NULL, + loss_reduction = NULL, + sample_size = NULL, + stop_iter = NULL + ) { args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -89,14 +94,20 @@ boost_tree <- #' @rdname parsnip_update #' @export update.boost_tree <- - function(object, - parameters = NULL, - mtry = NULL, trees = NULL, min_n = NULL, - tree_depth = NULL, learn_rate = NULL, - loss_reduction = NULL, sample_size = NULL, - stop_iter = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + mtry = NULL, + trees = NULL, + min_n = NULL, + tree_depth = NULL, + learn_rate = NULL, + loss_reduction = NULL, + sample_size = NULL, + stop_iter = NULL, + fresh = FALSE, + ... + ) { args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -149,7 +160,11 @@ translate.boost_tree <- function(x, engine = x$engine, ...) { # min_n parameters if (any(names(arg_vals) == "min_instances_per_node")) { arg_vals$min_instances_per_node <- - rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x)) + rlang::call2( + "min_rows", + rlang::eval_tidy(arg_vals$min_instances_per_node), + expr(x) + ) } ## ----------------------------------------------------------------------------- @@ -163,13 +178,37 @@ translate.boost_tree <- function(x, engine = x$engine, ...) 
{ #' @export check_args.boost_tree <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees") - check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size") - check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth") - check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n") + check_number_whole( + args$trees, + min = 0, + allow_null = TRUE, + call = call, + arg = "trees" + ) + check_number_decimal( + args$sample_size, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "sample_size" + ) + check_number_whole( + args$tree_depth, + min = 0, + allow_null = TRUE, + call = call, + arg = "tree_depth" + ) + check_number_whole( + args$min_n, + min = 0, + allow_null = TRUE, + call = call, + arg = "min_n" + ) invisible(object) } @@ -215,12 +254,23 @@ check_args.boost_tree <- function(object, call = rlang::caller_env()) { #' @keywords internal #' @export xgb_train <- function( - x, y, weights = NULL, - max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL, - colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1, - validation = 0, early_stop = NULL, counts = TRUE, - event_level = c("first", "second"), ...) { - + x, + y, + weights = NULL, + max_depth = 6, + nrounds = 15, + eta = 0.3, + colsample_bynode = NULL, + colsample_bytree = NULL, + min_child_weight = 1, + gamma = 0, + subsample = 1, + validation = 0, + early_stop = NULL, + counts = TRUE, + event_level = c("first", "second"), + ... +) { event_level <- rlang::arg_match(event_level, c("first", "second")) others <- list(...) @@ -243,11 +293,13 @@ xgb_train <- function( p <- ncol(x) x <- - as_xgb_data(x, y, - validation = validation, - event_level = event_level, - weights = weights) - + as_xgb_data( + x, + y, + validation = validation, + event_level = event_level, + weights = weights + ) if (!is.numeric(subsample) || subsample < 0 || subsample > 1) { cli::cli_abort("{.arg subsample} should be on [0, 1].") @@ -268,9 +320,9 @@ xgb_train <- function( if (min_child_weight > n) { cli::cli_warn( c( - "!" = "{min_child_weight} samples were requested but there were {n} rows + "!" = "{min_child_weight} samples were requested but there were {n} rows in the data.", - "i" = "{n} will be used." + "i" = "{n} will be used." ) ) min_child_weight <- min(min_child_weight, n) @@ -289,14 +341,14 @@ xgb_train <- function( others <- process_others(others, arg_list) main_args <- c( - list( - data = quote(x$data), - watchlist = quote(x$watchlist), - params = arg_list, - nrounds = nrounds, - early_stopping_rounds = early_stop - ), - others + list( + data = quote(x$data), + watchlist = quote(x$watchlist), + params = arg_list, + nrounds = nrounds, + early_stopping_rounds = early_stop + ), + others ) if (is.null(main_args$objective)) { @@ -365,7 +417,7 @@ recalc_param <- function(x, counts, denom) { } else { if (counts) { maybe_proportion(x, nm) - x <- min(denom, x)/denom + x <- min(denom, x) / denom } } x @@ -399,14 +451,26 @@ xgb_predict <- function(object, new_data, ...) 
{ x <- switch( object$params$objective %||% 3L, "binary:logitraw" = stats::binomial()$linkinv(res), - "multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE), - res) + "multi:softprob" = matrix( + res, + ncol = object$params$num_class, + byrow = TRUE + ), + res + ) x } -as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) { +as_xgb_data <- function( + x, + y, + validation = 0, + weights = NULL, + event_level = "first", + ... +) { lvls <- levels(y) n <- nrow(x) @@ -450,7 +514,6 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir missing = NA, info = info_list ) - } else { info_list <- list(label = y) if (!is.null(weights)) { @@ -470,7 +533,7 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir list(data = dat, watchlist = watch_list) } -get_event_level <- function(model_spec){ +get_event_level <- function(model_spec) { if ("event_level" %in% names(model_spec$eng_args)) { event_level <- get_expr(model_spec$eng_args$event_level) } else { @@ -492,15 +555,22 @@ multi_predict._xgb.Booster <- trees <- sort(trees) if (is.null(type)) { - if (object$spec$mode == "classification") + if (object$spec$mode == "classification") { type <- "class" - else + } else { type <- "numeric" + } } res <- - map(trees, xgb_by_tree, object = object, new_data = new_data, - type = type, ...) |> + map( + trees, + xgb_by_tree, + object = object, + new_data = new_data, + type = type, + ... + ) |> purrr::list_rbind() res <- arrange(res, .row, trees) res <- split(res[, -1], res$.row) @@ -582,10 +652,8 @@ C5.0_train <- cli::cli_abort("There are zero rows in the predictor set.") } - ctrl <- call2("C5.0Control", .ns = "C50") if (minCases > n) { - cli::cli_warn( c( "!" = "{minCases} samples were requested but there were {n} rows in the data.", @@ -616,16 +684,24 @@ C5.0_train <- #' @rdname multi_predict multi_predict._C5.0 <- function(object, new_data, type = NULL, trees = NULL, ...) { - if (is.null(trees)) + if (is.null(trees)) { trees <- min(object$fit$trials) + } trees <- sort(trees) - if (is.null(type)) + if (is.null(type)) { type <- "class" + } res <- - map(trees, C50_by_tree, object = object, - new_data = new_data, type = type, ...) |> + map( + trees, + C50_by_tree, + object = object, + new_data = new_data, + type = type, + ... + ) |> purrr::list_rbind() res <- arrange(res, .row, trees) res <- split(res[, -1], res$.row) @@ -648,5 +724,3 @@ C50_by_tree <- function(tree, object, new_data, type, ...) 
{ pred[[".row"]] <- seq_len(nrow(new_data)) pred[, c(".row", "trees", nms)] } - - diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 2ec5b03bd..0a457d78e 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -292,12 +292,11 @@ set_pred( as_tibble(x) }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) ) ) @@ -310,8 +309,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) @@ -387,7 +385,7 @@ set_fit( data = c(formula = "formula", data = "x"), protect = c("x", "formula", "type"), func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), - defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + defaults = list(seed = expr(sample.int(10^5, 1))) ) ) @@ -412,7 +410,7 @@ set_fit( data = c(formula = "formula", data = "x"), protect = c("x", "formula", "type"), func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), - defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + defaults = list(seed = expr(sample.int(10^5, 1))) ) ) diff --git a/R/c5_rules.R b/R/c5_rules.R index 904e2cb56..b794983c3 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -45,11 +45,12 @@ #' C5_rules() #' @export C5_rules <- - function(mode = "classification", - trees = NULL, - min_n = NULL, - engine = "C5.0") { - + function( + mode = "classification", + trees = NULL, + min_n = NULL, + engine = "C5.0" + ) { args <- list( trees = enquo(trees), min_n = enquo(min_n) @@ -86,11 +87,14 @@ C5_rules <- #' @inheritParams C5_rules #' @export update.C5_rules <- - function(object, - parameters = NULL, - trees = NULL, min_n = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + trees = NULL, + min_n = NULL, + fresh = FALSE, + ... + ) { args <- list( trees = enquo(trees), min_n = enquo(min_n) @@ -112,7 +116,6 @@ update.C5_rules <- #' @export check_args.C5_rules <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n") @@ -135,4 +138,3 @@ check_args.C5_rules <- function(object, call = rlang::caller_env()) { set_new_model("C5_rules") set_model_mode("C5_rules", "classification") - diff --git a/R/case_weights.R b/R/case_weights.R index 53c2f0847..d26bd752d 100644 --- a/R/case_weights.R +++ b/R/case_weights.R @@ -47,9 +47,11 @@ weights_to_numeric <- function(x, spec) { x } -patch_formula_environment_with_case_weights <- function(formula, - data, - case_weights) { +patch_formula_environment_with_case_weights <- function( + formula, + data, + case_weights +) { # `lm()` and `glm()` and others use the original model function call to # construct a call for `model.frame()`. 
That will normally fail because the # formula has its own environment attached (usually the global environment) diff --git a/R/control_parsnip.R b/R/control_parsnip.R index bdb9aaeb1..b83dd1f26 100644 --- a/R/control_parsnip.R +++ b/R/control_parsnip.R @@ -29,14 +29,16 @@ control_parsnip <- function(verbosity = 1L, catch = FALSE) { } check_control <- function(x, call = rlang::caller_env()) { - if (!is.list(x)) + if (!is.list(x)) { cli::cli_abort("{.arg control} should be a named list.", call = call) - if (!isTRUE(all.equal(sort(names(x)), c("catch", "verbosity")))) + } + if (!isTRUE(all.equal(sort(names(x)), c("catch", "verbosity")))) { cli::cli_abort( "{.arg control} should be a named list with elements {.field verbosity} and {.field catch}.", call = call ) + } check_number_whole(x$verbosity, call = call) check_bool(x$catch, call = call) x @@ -45,9 +47,11 @@ check_control <- function(x, call = rlang::caller_env()) { #' @export print.control_parsnip <- function(x, ...) { cat("parsnip control object\n") - if (x$verbosity > 1) + if (x$verbosity > 1) { cat(" - verbose level", x$verbosity, "\n") - if (x$catch) + } + if (x$catch) { cat(" - fit errors will be caught\n") + } invisible(x) } diff --git a/R/convert_data.R b/R/convert_data.R index 98615cab0..a3534d695 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -34,14 +34,16 @@ #' @keywords internal #' @export #' -.convert_form_to_xy_fit <- function(formula, - data, - ..., - na.action = na.omit, - indicators = "traditional", - composition = "data.frame", - remove_intercept = TRUE, - call = rlang::caller_env()) { +.convert_form_to_xy_fit <- function( + formula, + data, + ..., + na.action = na.omit, + indicators = "traditional", + composition = "data.frame", + remove_intercept = TRUE, + call = rlang::caller_env() +) { if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) { cli::cli_abort( "{.arg composition} should be either {.val data.frame}, {.val matrix}, or @@ -175,11 +177,13 @@ #' @rdname convert_helpers #' @keywords internal #' @export -.convert_form_to_xy_new <- function(object, - new_data, - na.action = na.pass, - composition = "data.frame", - call = rlang::caller_env()) { +.convert_form_to_xy_new <- function( + object, + new_data, + na.action = na.pass, + composition = "data.frame", + call = rlang::caller_env() +) { if (!(composition %in% c("data.frame", "matrix"))) { cli::cli_abort( "{.arg composition} should be either {.val data.frame} or {.val matrix}.", @@ -245,12 +249,14 @@ #' @keywords internal #' @export #' -.convert_xy_to_form_fit <- function(x, - y, - weights = NULL, - y_name = "..y", - remove_intercept = TRUE, - call = rlang::caller_env()) { +.convert_xy_to_form_fit <- function( + x, + y, + weights = NULL, + y_name = "..y", + remove_intercept = TRUE, + call = rlang::caller_env() +) { if (is.vector(x)) { cli::cli_abort("{.arg x} cannot be a vector.", call = call) } @@ -288,7 +294,10 @@ cli::cli_abort("{.arg weights} must be a numeric vector.", call = call) } if (length(weights) != nrow(x)) { - cli::cli_abort("{.arg weights} should have {nrow(x)} elements.", call = call) + cli::cli_abort( + "{.arg weights} should have {nrow(x)} elements.", + call = call + ) } form <- patch_formula_environment_with_case_weights( @@ -351,31 +360,41 @@ make_formula <- function(x, y, short = TRUE) { paste0(y, collapse = ","), ")~" ) - } else + } else { y_part <- paste0(y, "~") - if(short) + } + if (short) { form_text <- paste0(y_part, ".") - else + } else { form_text <- paste0(y_part, paste0(x, collapse = "+")) + } 
as.formula(form_text) } will_make_matrix <- function(y) { - if (is.matrix(y) | is.atomic(y)) + if (is.matrix(y) | is.atomic(y)) { return(FALSE) + } cls <- unique(unlist(lapply(y, class))) - if (length(cls) > 1) + if (length(cls) > 1) { return(FALSE) + } can_convert <- - vapply(y, function(x) - is.atomic(x) & !is.factor(x), logical(1)) + vapply( + y, + function(x) { + is.atomic(x) & !is.factor(x) + }, + logical(1) + ) all(can_convert) } check_dup_names <- function(x, y, call = rlang::caller_env()) { - if (is.vector(y)) + if (is.vector(y)) { return(invisible(NULL)) + } common_names <- intersect(colnames(x), colnames(y)) if (length(common_names) > 0) { diff --git a/R/cubist_rules.R b/R/cubist_rules.R index a8cc5407c..b30a4f95a 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -68,12 +68,13 @@ #' Kuhn M and Johnson K (2013). _Applied Predictive Modeling_. Springer. #' @export cubist_rules <- - function(mode = "regression", - committees = NULL, - neighbors = NULL, - max_rules = NULL, - engine = "Cubist") { - + function( + mode = "regression", + committees = NULL, + neighbors = NULL, + max_rules = NULL, + engine = "Cubist" + ) { args <- list( committees = enquo(committees), neighbors = enquo(neighbors), @@ -109,15 +110,19 @@ cubist_rules <- #' @inheritParams cubist_rules #' @export update.cubist_rules <- - function(object, - parameters = NULL, - committees = NULL, neighbors = NULL, max_rules = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + committees = NULL, + neighbors = NULL, + max_rules = NULL, + fresh = FALSE, + ... + ) { args <- list( committees = enquo(committees), - neighbors = enquo(neighbors), - max_rules = enquo(max_rules) + neighbors = enquo(neighbors), + max_rules = enquo(max_rules) ) update_spec( @@ -136,24 +141,33 @@ update.cubist_rules <- #' @export check_args.cubist_rules <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees") + check_number_whole( + args$committees, + allow_null = TRUE, + call = call, + arg = "committees" + ) msg <- "The number of committees should be {.code >= 1} and {.code <= 100}." if (!(is.null(args$committees)) && args$committees > 100) { object$args$committees <- rlang::new_quosure(100L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 100.")) - } + cli::cli_warn(c(msg, "Truncating to 100.")) + } if (!(is.null(args$committees)) && args$committees < 1) { object$args$committees <- rlang::new_quosure(1L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 1.")) + cli::cli_warn(c(msg, "Truncating to 1.")) } - check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors") + check_number_whole( + args$neighbors, + allow_null = TRUE, + call = call, + arg = "neighbors" + ) msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}." 
if (!(is.null(args$neighbors)) && args$neighbors > 9) { diff --git a/R/data.R b/R/data.R index 41b778416..3f39e8ec5 100644 --- a/R/data.R +++ b/R/data.R @@ -11,4 +11,3 @@ #' @examplesIf !parsnip:::is_cran_check() #' data(model_db) NULL - diff --git a/R/decision_tree.R b/R/decision_tree.R index 2d72ce0af..d0f6918d8 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -33,13 +33,17 @@ #' @export decision_tree <- - function(mode = "unknown", engine = "rpart", cost_complexity = NULL, - tree_depth = NULL, min_n = NULL) { - + function( + mode = "unknown", + engine = "rpart", + cost_complexity = NULL, + tree_depth = NULL, + min_n = NULL + ) { args <- list( - cost_complexity = enquo(cost_complexity), - tree_depth = enquo(tree_depth), - min_n = enquo(min_n) + cost_complexity = enquo(cost_complexity), + tree_depth = enquo(tree_depth), + min_n = enquo(min_n) ) new_model_spec( @@ -60,15 +64,19 @@ decision_tree <- #' @rdname parsnip_update #' @export update.decision_tree <- - function(object, - parameters = NULL, - cost_complexity = NULL, tree_depth = NULL, min_n = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + cost_complexity = NULL, + tree_depth = NULL, + min_n = NULL, + fresh = FALSE, + ... + ) { args <- list( - cost_complexity = enquo(cost_complexity), - tree_depth = enquo(tree_depth), - min_n = enquo(min_n) + cost_complexity = enquo(cost_complexity), + tree_depth = enquo(tree_depth), + min_n = enquo(min_n) ) update_spec( @@ -113,7 +121,11 @@ translate.decision_tree <- function(x, engine = x$engine, ...) { } if (any(names(arg_vals) == "min_instances_per_node")) { arg_vals$min_instances_per_node <- - rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x)) + rlang::call2( + "min_rows", + rlang::eval_tidy(arg_vals$min_instances_per_node), + expr(x) + ) } ## ----------------------------------------------------------------------------- diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index be3008502..533a714c8 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -121,12 +121,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) ) ) @@ -221,12 +220,11 @@ set_pred( as_tibble(x) }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) ) ) @@ -240,8 +238,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) @@ -278,8 +275,7 @@ set_fit( data = c(formula = "formula", data = "x"), protect = c("x", "formula"), func = c(pkg = "sparklyr", fun = "ml_decision_tree_regressor"), - defaults = - list(seed = expr(sample.int(10 ^ 5, 1))) + defaults = list(seed = expr(sample.int(10^5, 1))) ) ) @@ -304,8 +300,7 @@ set_fit( data = c(formula = "formula", data = "x"), protect = c("x", "formula"), func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), - defaults = - list(seed = expr(sample.int(10 ^ 5, 1))) + defaults = list(seed = expr(sample.int(10^5, 1))) ) ) diff --git a/R/descriptors.R b/R/descriptors.R index a8a97fdf8..7a9845cf8 100644 --- a/R/descriptors.R +++ b/R/descriptors.R @@ -113,19 
+113,24 @@ get_descr_form <- function(formula, data, call = rlang::caller_env()) { } get_descr_df <- function(formula, data, call = rlang::caller_env()) { - tmp_dat <- - .convert_form_to_xy_fit(formula, - data, - indicators = "none", - remove_intercept = TRUE, - call = call) + .convert_form_to_xy_fit( + formula, + data, + indicators = "none", + remove_intercept = TRUE, + call = call + ) - if(is.factor(tmp_dat$y)) { + if (is.factor(tmp_dat$y)) { .lvls <- function() { table(tmp_dat$y, dnn = NULL) } - } else .lvls <- function() { NA } + } else { + .lvls <- function() { + NA + } + } .preds <- function() { ncol(tmp_dat$x) @@ -176,17 +181,16 @@ get_descr_df <- function(formula, data, call = rlang::caller_env()) { } get_descr_spark <- function(formula, data) { - all_vars <- all.vars(formula) - if("." %in% all_vars){ + if ("." %in% all_vars) { tmpdata <- dplyr::collect(head(data, 1000)) f_terms <- stats::terms(formula, data = tmpdata) f_cols <- rownames(attr(f_terms, "factors")) } else { f_terms <- stats::terms(formula) f_cols <- rownames(attr(f_terms, "factors")) - term_data <- dplyr::select(data, !!! rlang::syms(f_cols)) + term_data <- dplyr::select(data, !!!rlang::syms(f_cols)) tmpdata <- dplyr::collect(head(term_data, 1000)) } @@ -195,23 +199,22 @@ get_descr_spark <- function(formula, data) { y_col <- f_cols[y_ind] classes <- purrr::map(tmpdata, class) - icats <- purrr::map_lgl(classes, ~.x == "character") + icats <- purrr::map_lgl(classes, ~ .x == "character") cats <- classes[icats] - cat_preds <- purrr::imap_lgl(cats, ~.y %in% f_term_labels) + cat_preds <- purrr::imap_lgl(cats, ~ .y %in% f_term_labels) cats <- cats[cat_preds] cat_levels <- imap( cats, - ~{ - p <- dplyr::group_by(data, !! rlang::sym(.y)) + ~ { + p <- dplyr::group_by(data, !!rlang::sym(.y)) p <- dplyr::summarise(p) dplyr::pull(p) } ) numeric_pred <- length(f_term_labels) - length(cat_levels) - - if(length(cat_levels) > 0){ - n_dummies <- purrr::map_dbl(cat_levels, ~length(.x) - 1) + if (length(cat_levels) > 0) { + n_dummies <- purrr::map_dbl(cat_levels, ~ length(.x) - 1) n_dummies <- sum(n_dummies) all_preds <- numeric_pred + n_dummies factor_pred <- length(cat_levels) @@ -225,36 +228,40 @@ get_descr_spark <- function(formula, data) { outs <- purrr::imap( out_cats, - ~{ - p <- dplyr::group_by(data, !! 
sym(.y)) + ~ { + p <- dplyr::group_by(data, !!sym(.y)) p <- dplyr::tally(p) dplyr::collect(p) } ) - if(length(outs) > 0){ + if (length(outs) > 0) { outs <- outs[[1]] - y_vals <- purrr::as_vector(outs[,2]) - names(y_vals) <- purrr::as_vector(outs[,1]) + y_vals <- purrr::as_vector(outs[, 2]) + names(y_vals) <- purrr::as_vector(outs[, 1]) y_vals <- y_vals[order(names(y_vals))] y_vals <- as.table(y_vals) - } else y_vals <- NA + } else { + y_vals <- NA + } obs <- dplyr::tally(data) |> dplyr::pull() - .cols <- function() all_preds + .cols <- function() all_preds .preds <- function() length(f_term_labels) - .obs <- function() obs - .lvls <- function() y_vals + .obs <- function() obs + .lvls <- function() y_vals .facts <- function() factor_pred - .x <- function() cli::cli_abort("Descriptor {.fn .x} not defined for Spark.") - .y <- function() cli::cli_abort("Descriptor {.fn .y} not defined for Spark.") - .dat <- function() cli::cli_abort("Descriptor {.fn .dat} not defined for Spark.") + .x <- function() cli::cli_abort("Descriptor {.fn .x} not defined for Spark.") + .y <- function() cli::cli_abort("Descriptor {.fn .y} not defined for Spark.") + .dat <- function() { + cli::cli_abort("Descriptor {.fn .dat} not defined for Spark.") + } # still need .x(), .y(), .dat() ? list( - .cols = .cols, + .cols = .cols, .preds = .preds, .obs = .obs, .lvls = .lvls, @@ -266,14 +273,13 @@ get_descr_spark <- function(formula, data) { } get_descr_xy <- function(x, y, call = rlang::caller_env()) { - .lvls <- if (is.factor(y)) { function() table(y, dnn = NULL) } else { function() NA } - .cols <- function() { + .cols <- function() { ncol(x) } @@ -281,15 +287,16 @@ get_descr_xy <- function(x, y, call = rlang::caller_env()) { ncol(x) } - .obs <- function() { + .obs <- function() { nrow(x) } .facts <- function() { - if(is.data.frame(x)) + if (is.data.frame(x)) { sum(vapply(x, is.factor, logical(1))) - else - sum(apply(x, 2, is.factor)) # would this always be zero? + } else { + sum(apply(x, 2, is.factor)) + } # would this always be zero? } .dat <- function() { @@ -305,7 +312,7 @@ get_descr_xy <- function(x, y, call = rlang::caller_env()) { } list( - .cols = .cols, + .cols = .cols, .preds = .preds, .obs = .obs, .lvls = .lvls, @@ -317,8 +324,9 @@ get_descr_xy <- function(x, y, call = rlang::caller_env()) { } has_exprs <- function(x) { - if(is.null(x) | is_varying(x) | is_missing_arg(x)) + if (is.null(x) | is_varying(x) | is_missing_arg(x)) { return(FALSE) + } is_symbolic(x) } @@ -334,9 +342,8 @@ requires_descrs <- function(object) { # given a quosure arg, does the expression contain a descriptor function? 
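# Illustrative example (not part of this patch): a quosure such as
#   rlang::quo(floor(.cols() / 2))
# references the descriptor .cols() and so counts, whereas
#   rlang::quo(2)
# does not.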
has_any_descrs <- function(x) { - .x_expr <- rlang::get_expr(x) - .x_env <- rlang::get_env(x, parent.frame()) + .x_env <- rlang::get_env(x, parent.frame()) # evaluated value # required so we don't pass an empty env to findGlobals(), which is an error @@ -358,7 +365,6 @@ has_any_descrs <- function(x) { } is_descr <- function(x) { - descrs <- list( ".cols", ".preds", @@ -377,18 +383,23 @@ is_descr <- function(x) { # descrs = list of functions that actually eval to .cols() poke_descrs <- function(descrs) { - descr_names <- names(descr_env) - old <- purrr::map(descr_names, ~{ - descr_env[[.x]] - }) + old <- purrr::map( + descr_names, + ~ { + descr_env[[.x]] + } + ) names(old) <- descr_names - purrr::walk(descr_names, ~{ - descr_env[[.x]] <- descrs[[.x]] - }) + purrr::walk( + descr_names, + ~ { + descr_env[[.x]] <- descrs[[.x]] + } + ) invisible(old) } @@ -411,13 +422,13 @@ scoped_descrs <- function(descrs, frame = caller_env()) { # with their actual implementations descr_env <- rlang::new_environment( data = list( - .cols = function() cli::cli_abort("Descriptor context not set"), + .cols = function() cli::cli_abort("Descriptor context not set"), .preds = function() cli::cli_abort("Descriptor context not set"), - .obs = function() cli::cli_abort("Descriptor context not set"), - .lvls = function() cli::cli_abort("Descriptor context not set"), + .obs = function() cli::cli_abort("Descriptor context not set"), + .lvls = function() cli::cli_abort("Descriptor context not set"), .facts = function() cli::cli_abort("Descriptor context not set"), - .x = function() cli::cli_abort("Descriptor context not set"), - .y = function() cli::cli_abort("Descriptor context not set"), - .dat = function() cli::cli_abort("Descriptor context not set") + .x = function() cli::cli_abort("Descriptor context not set"), + .y = function() cli::cli_abort("Descriptor context not set"), + .dat = function() cli::cli_abort("Descriptor context not set") ) ) diff --git a/R/discrim_flexible.R b/R/discrim_flexible.R index a0bbf5abd..b5d3a051c 100644 --- a/R/discrim_flexible.R +++ b/R/discrim_flexible.R @@ -27,12 +27,16 @@ #' #' @export discrim_flexible <- - function(mode = "classification", num_terms = NULL, prod_degree = NULL, - prune_method = NULL, engine = "earth") { - + function( + mode = "classification", + num_terms = NULL, + prod_degree = NULL, + prune_method = NULL, + engine = "earth" + ) { args <- list( - num_terms = enquo(num_terms), - prod_degree = enquo(prod_degree), + num_terms = enquo(num_terms), + prod_degree = enquo(prod_degree), prune_method = enquo(prune_method) ) @@ -60,15 +64,17 @@ discrim_flexible <- #' @inheritParams discrim_flexible #' @export update.discrim_flexible <- - function(object, - num_terms = NULL, - prod_degree = NULL, - prune_method = NULL, - fresh = FALSE, ...) { - + function( + object, + num_terms = NULL, + prod_degree = NULL, + prune_method = NULL, + fresh = FALSE, + ... 
+ ) { args <- list( - num_terms = enquo(num_terms), - prod_degree = enquo(prod_degree), + num_terms = enquo(num_terms), + prod_degree = enquo(prod_degree), prune_method = enquo(prune_method) ) @@ -86,12 +92,29 @@ update.discrim_flexible <- #' @export check_args.discrim_flexible <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree") - check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms") - check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method") + check_number_whole( + args$prod_degree, + min = 1, + allow_null = TRUE, + call = call, + arg = "prod_degree" + ) + check_number_whole( + args$num_terms, + min = 1, + allow_null = TRUE, + call = call, + arg = "num_terms" + ) + check_string( + args$prune_method, + allow_empty = FALSE, + allow_null = TRUE, + call = call, + arg = "prune_method" + ) invisible(object) } diff --git a/R/discrim_linear.R b/R/discrim_linear.R index 22eff4ea5..320a5f753 100644 --- a/R/discrim_linear.R +++ b/R/discrim_linear.R @@ -30,9 +30,12 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("discrim_linear")} #' @export discrim_linear <- - function(mode = "classification", penalty = NULL, regularization_method = NULL, - engine = "MASS") { - + function( + mode = "classification", + penalty = NULL, + regularization_method = NULL, + engine = "MASS" + ) { args <- list( penalty = rlang::enquo(penalty), regularization_method = rlang::enquo(regularization_method) @@ -57,11 +60,13 @@ discrim_linear <- #' @inheritParams discrim_linear #' @export update.discrim_linear <- - function(object, - penalty = NULL, - regularization_method = NULL, - fresh = FALSE, ...) { - + function( + object, + penalty = NULL, + regularization_method = NULL, + fresh = FALSE, + ... + ) { args <- list( penalty = rlang::enquo(penalty), regularization_method = rlang::enquo(regularization_method) @@ -81,10 +86,15 @@ update.discrim_linear <- #' @export check_args.discrim_linear <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal( + args$penalty, + min = 0, + allow_null = TRUE, + call = call, + arg = "penalty" + ) invisible(object) } diff --git a/R/discrim_quad.R b/R/discrim_quad.R index eb1ad8715..1cadd3d06 100644 --- a/R/discrim_quad.R +++ b/R/discrim_quad.R @@ -28,8 +28,11 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("discrim_quad")} #' @export discrim_quad <- - function(mode = "classification", regularization_method = NULL, engine = "MASS") { - + function( + mode = "classification", + regularization_method = NULL, + engine = "MASS" + ) { args <- list(regularization_method = rlang::enquo(regularization_method)) new_model_spec( @@ -51,10 +54,7 @@ discrim_quad <- #' @inheritParams discrim_quad #' @export update.discrim_quad <- - function(object, - regularization_method = NULL, - fresh = FALSE, ...) { - + function(object, regularization_method = NULL, fresh = FALSE, ...) 
{ args <- list(regularization_method = rlang::enquo(regularization_method)) update_spec( @@ -71,4 +71,3 @@ update.discrim_quad <- set_new_model("discrim_quad") set_model_mode("discrim_quad", "classification") - diff --git a/R/discrim_regularized.R b/R/discrim_regularized.R index e6f209c4c..b728bb874 100644 --- a/R/discrim_regularized.R +++ b/R/discrim_regularized.R @@ -45,9 +45,12 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("discrim_regularized")} #' @export discrim_regularized <- - function(mode = "classification", frac_common_cov = NULL, frac_identity = NULL, - engine = "klaR") { - + function( + mode = "classification", + frac_common_cov = NULL, + frac_identity = NULL, + engine = "klaR" + ) { args <- list( frac_common_cov = rlang::enquo(frac_common_cov), frac_identity = rlang::enquo(frac_identity) @@ -72,11 +75,13 @@ discrim_regularized <- #' @inheritParams discrim_regularized #' @export update.discrim_regularized <- - function(object, - frac_common_cov = NULL, - frac_identity = NULL, - fresh = FALSE, ...) { - + function( + object, + frac_common_cov = NULL, + frac_identity = NULL, + fresh = FALSE, + ... + ) { args <- list( frac_common_cov = rlang::enquo(frac_common_cov), frac_identity = rlang::enquo(frac_identity) @@ -96,12 +101,25 @@ update.discrim_regularized <- #' @export check_args.discrim_regularized <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov") - check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity") - + check_number_decimal( + args$frac_common_cov, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "frac_common_cov" + ) + check_number_decimal( + args$frac_identity, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "frac_identity" + ) + invisible(object) } @@ -110,4 +128,3 @@ check_args.discrim_regularized <- function(object, call = rlang::caller_env()) { set_new_model("discrim_regularized") set_model_mode("discrim_regularized", "classification") - diff --git a/R/engine_docs.R b/R/engine_docs.R index 044d10af2..8643a779d 100644 --- a/R/engine_docs.R +++ b/R/engine_docs.R @@ -38,7 +38,9 @@ knit_engine_docs <- function(pattern = NULL) { errors <- purrr::map_chr(errors, ~ cli::ansi_strip(as.character(.x))) |> purrr::map2_chr(error_nms, ~ paste0(.y, ": ", .x)) |> - purrr::map_chr(~ gsub("Error in .f(.x[[i]], ...) :", "", .x, fixed = TRUE)) + purrr::map_chr( + ~ gsub("Error in .f(.x[[i]], ...) 
:", "", .x, fixed = TRUE) + ) cat("There were failures duing knitting:\n\n") cat(errors) cat("\n\n") @@ -58,8 +60,17 @@ knit_engine_docs <- function(pattern = NULL) { # ------------------------------------------------------------------------------ extensions <- function() { - c("baguette", "censored", "discrim", "multilevelmod", "plsmod", - "poissonreg", "rules", "bonsai", "agua") + c( + "baguette", + "censored", + "discrim", + "multilevelmod", + "plsmod", + "poissonreg", + "rules", + "bonsai", + "agua" + ) } # ------------------------------------------------------------------------------ @@ -108,7 +119,6 @@ update_model_info_file <- function(path = "inst/models.tsv") { # ------------------------------------------------------------------------------ - #' Tools for documenting engines #' #' @description @@ -217,15 +227,17 @@ make_engine_list <- function(mod) { ) eng <- dplyr::left_join(eng, exts, by = "engine") - eng_table <- eng |> dplyr::arrange(.order) |> dplyr::select(-mode) |> dplyr::distinct(engine, .keep_all = TRUE) |> dplyr::mutate( - item = glue::glue(" \\item \\code{\\link[|topic|]{|engine|}|default||has_ext|}", - .open = "|", .close = "|") + item = glue::glue( + " \\item \\code{\\link[|topic|]{|engine|}|default||has_ext|}", + .open = "|", + .close = "|" + ) ) notes <- paste0("\n", cli::symbol$sup_1, " The default engine.") @@ -238,18 +250,29 @@ make_engine_list <- function(mod) { sort() |> combine_words() notes <- paste0( - notes, " ", - cli::symbol$sup_2, " Requires a parsnip extension package for ", - ext_modes, ".") + notes, + " ", + cli::symbol$sup_2, + " Requires a parsnip extension package for ", + ext_modes, + "." + ) } else { - notes <- paste0(notes, " ", cli::symbol$sup_2, " Requires a parsnip extension package.") + notes <- paste0( + notes, + " ", + cli::symbol$sup_2, + " Requires a parsnip extension package." + ) } } - items <- glue::glue_collapse(eng_table$item, sep = "\n") - res <- glue::glue("|main|\n\\itemize{\n|items|\n}\n\n |notes|", - .open = "|", .close = "|") + res <- glue::glue( + "|main|\n\\itemize{\n|items|\n}\n\n |notes|", + .open = "|", + .close = "|" + ) res } @@ -263,21 +286,26 @@ get_default_engine <- function(mod, pkg = "parsnip") { #' @export #' @rdname doc-tools -make_seealso_list <- function(mod, pkg= "parsnip") { +make_seealso_list <- function(mod, pkg = "parsnip") { requireNamespace(pkg, quietly = TRUE) eng <- find_engine_files(mod) - main <- c("\\code{\\link[=fit.model_spec]{fit()}}", - "\\code{\\link[=set_engine]{set_engine()}}", - "\\code{\\link[=update]{update()}}") + main <- c( + "\\code{\\link[=fit.model_spec]{fit()}}", + "\\code{\\link[=set_engine]{set_engine()}}", + "\\code{\\link[=update]{update()}}" + ) if (length(eng) == 0) { return(paste0(main, collapse = ", ")) } res <- - glue::glue("\\code{\\link[|eng$topic|]{|eng$engine| engine details}}", - .open = "|", .close = "|") + glue::glue( + "\\code{\\link[|eng$topic|]{|eng$engine| engine details}}", + .open = "|", + .close = "|" + ) if (pkg != "parsnip") { main <- NULL diff --git a/R/engines.R b/R/engines.R index 4ca62686c..906a09ccc 100644 --- a/R/engines.R +++ b/R/engines.R @@ -1,4 +1,3 @@ - specific_model <- function(x) { cls <- class(x) cls[cls != "model_spec"] @@ -12,8 +11,9 @@ possible_engines <- function(object, ...) 
{ # ------------------------------------------------------------------------------ -shhhh <- function(x) +shhhh <- function(x) { suppressPackageStartupMessages(requireNamespace(x, quietly = TRUE)) +} is_installed <- function(pkg) { res <- try(shhhh(pkg), silent = TRUE) @@ -119,8 +119,13 @@ set_engine.model_spec <- function(object, engine, ...) { # determine if the model specification could feasibly match any entry # in the union of the parsnip model environment and model_info_table. # if not, trigger an error based on the (possibly inferred) model spec slots. - if (!spec_is_possible(spec = object, - engine = object$engine, user_specified_engine = TRUE)) { + if ( + !spec_is_possible( + spec = object, + engine = object$engine, + user_specified_engine = TRUE + ) + ) { check_spec_mode_engine_val(mod_type, object$engine, object$mode) } @@ -128,7 +133,8 @@ set_engine.model_spec <- function(object, engine, ...) { lifecycle::deprecate_warn( "0.1.6", "set_engine(engine = 'cannot be liquidSVM')", - details = "The liquidSVM package is no longer available on CRAN.") + details = "The liquidSVM package is no longer available on CRAN." + ) } new_model_spec( diff --git a/R/extract.R b/R/extract.R index 3ed40fa3b..efbe76e50 100644 --- a/R/extract.R +++ b/R/extract.R @@ -75,7 +75,10 @@ extract_fit_engine.model_fit <- function(x, ...) { if (any(names(x) == "fit")) { return(x$fit) } - cli::cli_abort("The model fit does not have an engine fit.", .internal = TRUE) + cli::cli_abort( + "The model fit does not have an engine fit.", + .internal = TRUE + ) } #' @export @@ -110,7 +113,7 @@ extract_parameter_set_dials.model_spec <- function(x, ...) { ) } -eval_call_info <- function(x) { +eval_call_info <- function(x) { if (!is.null(x)) { # Look for other options allowed_opts <- c("range", "trans", "values") @@ -119,7 +122,10 @@ eval_call_info <- function(x) { } else { opts <- list() } - res <- try(rlang::eval_tidy(rlang::call2(x$fun, .ns = x$pkg, !!!opts)), silent = TRUE) + res <- try( + rlang::eval_tidy(rlang::call2(x$fun, .ns = x$pkg, !!!opts)), + silent = TRUE + ) if (inherits(res, "try-error")) { stop(paste0("Error when calling ", x$fun, "(): ", as.character(res))) } diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 51a2c7374..e67cd4d1b 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -5,7 +5,6 @@ form_form <- function(object, control, env, ..., call = rlang::caller_env()) { - if (inherits(env$data, "data.frame")) { check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object) @@ -13,12 +12,16 @@ form_form <- encoding_info <- vctrs::vec_slice( encoding_info, - encoding_info$mode == object$mode & encoding_info$engine == object$engine + encoding_info$mode == object$mode & + encoding_info$engine == object$engine ) remove_intercept <- encoding_info |> dplyr::pull(remove_intercept) if (remove_intercept) { - env$data <- env$data[, colnames(env$data) != "(Intercept)", drop = FALSE] + env$data <- env$data[, + colnames(env$data) != "(Intercept)", + drop = FALSE + ] } } @@ -60,13 +63,14 @@ form_form <- res } -xy_xy <- function(object, - env, - control, - target = "none", - ..., - call = rlang::caller_env()) { - +xy_xy <- function( + object, + env, + control, + target = "none", + ..., + call = rlang::caller_env() +) { if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark")) { cli::cli_abort( "spark objects can only be used with the formula interface to {.fun fit}.", @@ -121,9 +125,14 @@ xy_xy <- function(object, res } -form_xy <- function(object, control, env, - target = "none", ..., call = 
rlang::caller_env()) { - +form_xy <- function( + object, + control, + env, + target = "none", + ..., + call = rlang::caller_env() +) { encoding_info <- get_encoding(class(object)[1]) |> dplyr::filter(mode == object$mode, engine == object$engine) @@ -166,7 +175,6 @@ form_xy <- function(object, control, env, } xy_form <- function(object, env, control, ...) { - check_outcome(env$y, object) encoding_info <- get_encoding(class(object)[1]) diff --git a/R/gen_additive_mod.R b/R/gen_additive_mod.R index 3ed3d4c5d..df469f135 100644 --- a/R/gen_additive_mod.R +++ b/R/gen_additive_mod.R @@ -30,11 +30,12 @@ #' gen_additive_mod() #' #' @export -gen_additive_mod <- function(mode = "unknown", - select_features = NULL, - adjust_deg_free = NULL, - engine = "mgcv") { - +gen_additive_mod <- function( + mode = "unknown", + select_features = NULL, + adjust_deg_free = NULL, + engine = "mgcv" +) { args <- list( select_features = rlang::enquo(select_features), adjust_deg_free = rlang::enquo(adjust_deg_free) @@ -42,7 +43,7 @@ gen_additive_mod <- function(mode = "unknown", new_model_spec( "gen_additive_mod", - args = args, + args = args, eng_args = NULL, mode = mode, user_specified_mode = !missing(mode), @@ -50,18 +51,19 @@ gen_additive_mod <- function(mode = "unknown", engine = engine, user_specified_engine = !missing(engine) ) - } #' @export #' @rdname parsnip_update #' @inheritParams gen_additive_mod -update.gen_additive_mod <- function(object, - select_features = NULL, - adjust_deg_free = NULL, - parameters = NULL, - fresh = FALSE, ...) { - +update.gen_additive_mod <- function( + object, + select_features = NULL, + adjust_deg_free = NULL, + parameters = NULL, + fresh = FALSE, + ... +) { args <- list( select_features = rlang::enquo(select_features), adjust_deg_free = rlang::enquo(adjust_deg_free) @@ -96,10 +98,12 @@ fit_xy.gen_additive_mod <- function(object, ...) { if ("workflows" %in% trace$namespace & identical(object$engine, "mgcv")) { cli::cli_abort( - c("!" = "When working with generalized additive models, please supply the + c( + "!" = "When working with generalized additive models, please supply the model specification to {.fun workflows::add_model} along with a \\ {.arg formula} argument.", - "i" = "See {.help parsnip::model_formula} to learn more."), + "i" = "See {.help parsnip::model_formula} to learn more." 
+ ), call = NULL ) } diff --git a/R/gen_additive_mod_data.R b/R/gen_additive_mod_data.R index ff5f59ff2..d5b4e54a4 100644 --- a/R/gen_additive_mod_data.R +++ b/R/gen_additive_mod_data.R @@ -1,4 +1,3 @@ - set_new_model("gen_additive_mod") set_model_mode("gen_additive_mod", "classification") set_model_mode("gen_additive_mod", "regression") @@ -6,38 +5,43 @@ set_model_mode("gen_additive_mod", "regression") # ------------------------------------------------------------------------------ #### REGRESION ---- set_model_engine(model = "gen_additive_mod", mode = "regression", eng = "mgcv") -set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv", mode = "regression") +set_dependency( + model = "gen_additive_mod", + eng = "mgcv", + pkg = "mgcv", + mode = "regression" +) #Args # TODO make dials PR set_model_arg( - model = "gen_additive_mod", - eng = "mgcv", - parsnip = "select_features", - original = "select", - func = list(pkg = "dials", fun = "select_features"), + model = "gen_additive_mod", + eng = "mgcv", + parsnip = "select_features", + original = "select", + func = list(pkg = "dials", fun = "select_features"), has_submodel = FALSE ) set_model_arg( - model = "gen_additive_mod", - eng = "mgcv", - parsnip = "adjust_deg_free", - original = "gamma", - func = list(pkg = "dials", fun = "adjust_deg_free"), + model = "gen_additive_mod", + eng = "mgcv", + parsnip = "adjust_deg_free", + original = "gamma", + func = list(pkg = "dials", fun = "adjust_deg_free"), has_submodel = FALSE ) set_encoding( model = "gen_additive_mod", - eng = "mgcv", - mode = "regression", + eng = "mgcv", + mode = "regression", options = list( predictor_indicators = "none", - compute_intercept = FALSE, - remove_intercept = FALSE, - allow_sparse_x = FALSE + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE ) ) @@ -54,12 +58,12 @@ set_fit( ) set_pred( - model = "gen_additive_mod", - eng = "mgcv", - mode = "regression", - type = "numeric", - value = list( - pre = NULL, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, post = function(x, object) as.numeric(x), func = c(fun = "predict"), args = list( @@ -71,12 +75,12 @@ set_pred( ) set_pred( - model = "gen_additive_mod", - eng = "mgcv", - mode = "regression", - type = "conf_int", - value = list( - pre = NULL, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, post = linear_lp_to_conf_int, func = c(fun = "predict"), args = list( @@ -89,12 +93,12 @@ set_pred( ) set_pred( - model = "gen_additive_mod", - eng = "mgcv", - mode = "regression", - type = "raw", - value = list( - pre = NULL, + model = "gen_additive_mod", + eng = "mgcv", + mode = "regression", + type = "raw", + value = list( + pre = NULL, post = NULL, func = c(fun = "predict"), args = list( @@ -106,18 +110,27 @@ set_pred( # ------------------------------------------------------------------------------ #### CLASSIFICATION -set_model_engine(model = "gen_additive_mod", mode = "classification", eng = "mgcv") -set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv", mode = "classification") +set_model_engine( + model = "gen_additive_mod", + mode = "classification", + eng = "mgcv" +) +set_dependency( + model = "gen_additive_mod", + eng = "mgcv", + pkg = "mgcv", + mode = "classification" +) set_encoding( model = "gen_additive_mod", - eng = "mgcv", - mode = "classification", + eng = "mgcv", + mode = "classification", options = list( 
predictor_indicators = "none", - compute_intercept = FALSE, - remove_intercept = FALSE, - allow_sparse_x = FALSE + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE ) ) @@ -136,12 +149,12 @@ set_fit( ) set_pred( - model = "gen_additive_mod", - eng = "mgcv", - mode = "classification", - type = "class", - value = list( - pre = NULL, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "class", + value = list( + pre = NULL, post = function(x, object) { if (is.array(x)) { x <- as.vector(x) @@ -159,11 +172,11 @@ set_pred( ) set_pred( - model = "gen_additive_mod", - eng = "mgcv", - mode = "classification", - type = "prob", - value = list( + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "prob", + value = list( pre = NULL, post = function(x, object) { if (is.array(x)) { @@ -183,12 +196,12 @@ set_pred( ) set_pred( - model = "gen_additive_mod", - eng = "mgcv", - mode = "classification", - type = "raw", - value = list( - pre = NULL, + model = "gen_additive_mod", + eng = "mgcv", + mode = "classification", + type = "raw", + value = list( + pre = NULL, post = NULL, func = c(fun = "predict"), args = list( @@ -208,13 +221,11 @@ set_pred( pre = NULL, post = logistic_lp_to_conf_int, func = c(fun = "predict"), - args = - list( - object = rlang::expr(object$fit), - newdata = rlang::expr(new_data), - type = "link", - se.fit = TRUE - ) + args = list( + object = rlang::expr(object$fit), + newdata = rlang::expr(new_data), + type = "link", + se.fit = TRUE + ) ) ) - diff --git a/R/glm_grouped.R b/R/glm_grouped.R index e8e6d90c1..fe9bb0d75 100644 --- a/R/glm_grouped.R +++ b/R/glm_grouped.R @@ -115,6 +115,11 @@ glm_grouped <- function(formula, data, weights, ...) { values_from = "..weights", values_fill = 0L ) - cl <- rlang::call2("glm", rlang::expr(formula), data = rlang::expr(data), !!!opts) + cl <- rlang::call2( + "glm", + rlang::expr(formula), + data = rlang::expr(data), + !!!opts + ) rlang::eval_tidy(cl) } diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index 7bb19e2f1..b48121121 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -11,7 +11,6 @@ # predict_numeric.model_fit() # predict.() - # glmnet call stack using `multi_predict` when object has # classes "_" and "model_fit": # @@ -26,20 +25,25 @@ # predict_raw.model_fit(opts = list(s = penalty)) # predict.() - -predict_glmnet <- function(object, - new_data, - type = NULL, - opts = list(), - penalty = NULL, - multi = FALSE, - ...) { +predict_glmnet <- function( + object, + new_data, + type = NULL, + opts = list(), + penalty = NULL, + multi = FALSE, + ... +) { # See discussion in https://github.com/tidymodels/parsnip/issues/195 if (is.null(penalty) & !is.null(object$spec$args$penalty)) { penalty <- object$spec$args$penalty } - object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) + object$spec$args$penalty <- .check_glmnet_penalty_predict( + penalty, + object, + multi + ) object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) @@ -60,7 +64,7 @@ predict_classprob_glmnet <- function(object, new_data, ...) { predict_classprob.model_fit(object, new_data = new_data, ...) } -predict_raw_glmnet <- function(object, new_data, opts = list(), ...) { +predict_raw_glmnet <- function(object, new_data, opts = list(), ...) 
{ object$spec <- eval_args(object$spec) opts$s <- object$spec$args$penalty @@ -143,7 +147,7 @@ organize_glmnet_pre_pred <- function(x, object) { if (is_sparse_matrix(x)) { return(x) } - + as.matrix(x) } @@ -160,7 +164,7 @@ organize_glmnet_prob <- function(x, object) { organize_multnet_class <- function(x, object) { if (vec_size(x) > 1) { - x <- x[,1] + x <- x[, 1] } else { x <- as.character(x) } @@ -169,20 +173,22 @@ organize_multnet_class <- function(x, object) { organize_multnet_prob <- function(x, object) { if (vec_size(x) > 1) { - x <- as_tibble(x[,,1]) + x <- as_tibble(x[,, 1]) } else { - x <- tibble::as_tibble_row(x[,,1]) + x <- tibble::as_tibble_row(x[,, 1]) } x } # ------------------------------------------------------------------------- -multi_predict_glmnet <- function(object, - new_data, - type = NULL, - penalty = NULL, - ...) { +multi_predict_glmnet <- function( + object, + new_data, + type = NULL, + penalty = NULL, + ... +) { type <- check_pred_type(object, type) check_spec_pred_type(object, type) if (type == "prob") { @@ -211,30 +217,41 @@ multi_predict_glmnet <- function(object, model_type <- class(object$spec)[1] if (object$spec$mode == "classification") { - if (type == "prob" | - model_type == "logistic_reg") { + if ( + type == "prob" | + model_type == "logistic_reg" + ) { dots$type <- "response" } else { dots$type <- type } } - pred <- predict(object, new_data = new_data, type = "raw", - opts = dots, penalty = penalty, multi = TRUE) - + pred <- predict( + object, + new_data = new_data, + type = "raw", + opts = dots, + penalty = penalty, + multi = TRUE + ) res <- switch( model_type, "linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty), - "logistic_reg" = format_glmnet_multi_logistic_reg(pred, - penalty = penalty, - type = type, - lvl = object$lvl), - "multinom_reg" = format_glmnet_multi_multinom_reg(pred, - penalty = penalty, - type = type, - lvl = object$lvl, - n_obs = nrow(new_data)) + "logistic_reg" = format_glmnet_multi_logistic_reg( + pred, + penalty = penalty, + type = type, + lvl = object$lvl + ), + "multinom_reg" = format_glmnet_multi_multinom_reg( + pred, + penalty = penalty, + type = type, + lvl = object$lvl, + n_obs = nrow(new_data) + ) ) res @@ -286,9 +303,11 @@ format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) { if (type == "class") { pred <- pred |> - dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]), - .pred_class = factor(.pred_class, levels = lvl), - .keep = "unused") + dplyr::mutate( + .pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]), + .pred_class = factor(.pred_class, levels = lvl), + .keep = "unused" + ) } else { pred <- pred |> dplyr::mutate(.pred_class_2 = 1 - .pred) |> @@ -392,8 +411,12 @@ format_glmnet_multinom_class <- function(pred, penalty, lvl, n_obs) { #' @rdname glmnet_helpers #' @keywords internal #' @export -.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE, - call = rlang::caller_env()) { +.check_glmnet_penalty_predict <- function( + penalty = NULL, + object, + multi = FALSE, + call = rlang::caller_env() +) { if (is.null(penalty)) { penalty <- object$fit$lambda } @@ -440,4 +463,3 @@ set_glmnet_penalty_path <- function(x) { } x } - diff --git a/R/install_packages.R b/R/install_packages.R index 41a8be335..d0e48d9f8 100644 --- a/R/install_packages.R +++ b/R/install_packages.R @@ -1,7 +1,9 @@ # Installs packages needed to run `knit_engine_docs()`. 
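# Illustrative usage (not part of this patch): a call like
#   install_engine_packages(extension = TRUE, extras = FALSE)
# would install the engine packages registered by the parsnip extension
# packages while skipping the extra packages that are only needed to render
# the Rmd engine docs.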
-install_engine_packages <- function(extension = TRUE, extras = TRUE, - ignore_pkgs = c("stats", "liquidSVM", - "parsnip")) { +install_engine_packages <- function( + extension = TRUE, + extras = TRUE, + ignore_pkgs = c("stats", "liquidSVM", "parsnip") +) { bio_pkgs <- c() if (extension) { @@ -26,8 +28,16 @@ install_engine_packages <- function(extension = TRUE, extras = TRUE, } if (extras) { - rmd_pkgs <- c("ape", "broom.mixed", "Cubist", "glmnet", "quantreg", - "rmarkdown", "tidymodels", "xrf") + rmd_pkgs <- c( + "ape", + "broom.mixed", + "Cubist", + "glmnet", + "quantreg", + "rmarkdown", + "tidymodels", + "xrf" + ) engine_packages <- unique(c(engine_packages, rmd_pkgs)) } diff --git a/R/linear_reg.R b/R/linear_reg.R index e0def96fd..ef97fa09c 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -39,11 +39,7 @@ #' linear_reg() #' @export linear_reg <- - function(mode = "regression", - engine = "lm", - penalty = NULL, - mixture = NULL) { - + function(mode = "regression", engine = "lm", penalty = NULL, mixture = NULL) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -84,11 +80,14 @@ translate.linear_reg <- function(x, engine = x$engine, ...) { #' @rdname parsnip_update #' @export update.linear_reg <- - function(object, - parameters = NULL, - penalty = NULL, mixture = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, + ... + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -108,11 +107,23 @@ update.linear_reg <- #' @export check_args.linear_reg <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal( + args$mixture, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "mixture" + ) + check_number_decimal( + args$penalty, + min = 0, + allow_null = TRUE, + call = call, + arg = "penalty" + ) # ------------------------------------------------------------------------------ # We want to avoid folks passing in a poisson family instead of using @@ -131,7 +142,8 @@ check_args.linear_reg <- function(object, call = rlang::caller_env()) { cli::cli_abort( "A Poisson family was requested for {.fn linear_reg}. 
Please use {.fn poisson_reg} and the engines in the {.pkg poissonreg} package.", - call = rlang::call2("linear_reg")) + call = rlang::call2("linear_reg") + ) } } diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 35b001138..d75675b97 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -41,13 +41,12 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response", - rankdeficient = "simple" - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response", + rankdeficient = "simple" + ) ) ) @@ -64,14 +63,13 @@ set_pred( setNames(c(".pred_lower", ".pred_upper")) }, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "confidence", - level = expr(level), - type = "response" - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "confidence", + level = expr(level), + type = "response" + ) ) ) set_pred( @@ -87,14 +85,13 @@ set_pred( setNames(c(".pred_lower", ".pred_upper")) }, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "prediction", - level = expr(level), - type = "response" - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "prediction", + level = expr(level), + type = "response" + ) ) ) @@ -149,12 +146,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) ) ) @@ -167,13 +163,12 @@ set_pred( pre = NULL, post = linear_lp_to_conf_int, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - se.fit = TRUE, - type = "link" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + se.fit = TRUE, + type = "link" + ) ) ) @@ -247,13 +242,12 @@ set_pred( pre = NULL, post = .organize_glmnet_pred, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(organize_glmnet_pre_pred(new_data, object)), - type = "response", - s = expr(object$spec$args$penalty) - ) + args = list( + object = expr(object$fit), + newx = expr(organize_glmnet_pre_pred(new_data, object)), + type = "response", + s = expr(object$spec$args$penalty) + ) ) ) @@ -266,9 +260,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list(object = expr(object$fit), - newx = expr(as.matrix(new_data))) + args = list(object = expr(object$fit), newx = expr(as.matrix(new_data))) ) ) @@ -324,28 +316,26 @@ set_pred( post = function(results, object) { res <- tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$pred$conf_int$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$pred$conf_int$extras$level, - lower = FALSE - ), + .pred_lower = convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level + ), + .pred_upper = convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level, + lower = FALSE + ), ) - if (object$spec$method$pred$conf_int$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) { res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + } res }, func = c(pkg = "parsnip", fun = "stan_conf_int"), - args = - list( - 
object = expr(object$fit), - newdata = expr(new_data) - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data) + ) ) ) @@ -359,29 +349,27 @@ set_pred( post = function(results, object) { res <- tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$pred$pred_int$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$pred$pred_int$extras$level, - lower = FALSE - ), + .pred_lower = convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level + ), + .pred_upper = convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level, + lower = FALSE + ), ) - if (object$spec$method$pred$pred_int$extras$std_error) + if (object$spec$method$pred$pred_int$extras$std_error) { res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + } res }, func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - seed = expr(sample.int(10^5, 1)) - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) ) @@ -465,7 +453,6 @@ set_pred( # ------------------------------------------------------------------------------ - set_model_engine("linear_reg", "regression", "keras") set_dependency("linear_reg", "keras", "keras") set_dependency("linear_reg", "keras", "magrittr") @@ -518,7 +505,6 @@ set_pred( # ------------------------------------------------------------------------------ - set_model_engine("linear_reg", "regression", "brulee") set_dependency("linear_reg", "brulee", "brulee") @@ -574,19 +560,27 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) ) ) # ------------------------------------------------------------------------------ -set_model_engine(model = "linear_reg", mode = "quantile regression", eng = "quantreg") -set_dependency(model = "linear_reg", eng = "quantreg", pkg = "quantreg", mode = "quantile regression") +set_model_engine( + model = "linear_reg", + mode = "quantile regression", + eng = "quantreg" +) +set_dependency( + model = "linear_reg", + eng = "quantreg", + pkg = "quantreg", + mode = "quantile regression" +) set_fit( model = "linear_reg", @@ -621,10 +615,9 @@ set_pred( pre = NULL, post = matrix_to_quantile_pred, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data) + ) ) ) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 6d0f6c6d8..e0563a798 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -46,11 +46,12 @@ #' logistic_reg() #' @export logistic_reg <- - function(mode = "classification", - engine = "glm", - penalty = NULL, - mixture = NULL) { - + function( + mode = "classification", + engine = "glm", + penalty = NULL, + mixture = NULL + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -88,18 +89,25 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) 
{ if (engine == "LiblineaR") { # convert parameter arguments new_penalty <- rlang::eval_tidy(x$args$penalty) - if (is.numeric(new_penalty)) - arg_vals$cost <- rlang::new_quosure(1 / new_penalty, env = rlang::empty_env()) + if (is.numeric(new_penalty)) { + arg_vals$cost <- rlang::new_quosure( + 1 / new_penalty, + env = rlang::empty_env() + ) + } if (any(arg_names == "type")) { - if (is.numeric(quo_get_expr(arg_vals$type))) + if (is.numeric(quo_get_expr(arg_vals$type))) { if (quo_get_expr(x$args$mixture) == 0) { - arg_vals$type <- 0 ## ridge + arg_vals$type <- 0 ## ridge } else if (quo_get_expr(x$args$mixture) == 1) { - arg_vals$type <- 6 ## lasso + arg_vals$type <- 6 ## lasso } else { - cli::cli_abort("For the LiblineaR engine, {.arg mixture} must be 0 or 1.") + cli::cli_abort( + "For the LiblineaR engine, {.arg mixture} must be 0 or 1." + ) } + } } x$method$fit$args <- arg_vals } @@ -112,11 +120,14 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) { #' @rdname parsnip_update #' @export update.logistic_reg <- - function(object, - parameters = NULL, - penalty = NULL, mixture = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, + ... + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -136,20 +147,34 @@ update.logistic_reg <- #' @export check_args.logistic_reg <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal( + args$mixture, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "mixture" + ) + check_number_decimal( + args$penalty, + min = 0, + allow_null = TRUE, + call = call, + arg = "penalty" + ) if (object$engine == "LiblineaR") { if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { cli::cli_abort( - c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\ + c( + "x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\ not {args$mixture}.", "i" = "Choose a pure ridge model with {.code mixture = 0} or \\ a pure lasso model with {.code mixture = 1}.", - "!" = "The {.pkg Liblinear} engine does not support other values."), + "!" = "The {.pkg Liblinear} engine does not support other values." 
+ ), call = call ) } diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 485298300..414ce2c59 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -40,12 +40,11 @@ set_pred( pre = NULL, post = prob_to_class_2, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -62,12 +61,11 @@ set_pred( x }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -80,11 +78,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) @@ -97,13 +94,12 @@ set_pred( pre = NULL, post = logistic_lp_to_conf_int, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - se.fit = TRUE, - type = "link" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + se.fit = TRUE, + type = "link" + ) ) ) @@ -163,13 +159,15 @@ set_pred( pre = NULL, post = organize_glmnet_class, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data[, rownames(object$fit$beta), drop = FALSE])), - type = "response", - s = expr(object$spec$args$penalty) - ) + args = list( + object = expr(object$fit), + newx = expr(as.matrix(new_data[, + rownames(object$fit$beta), + drop = FALSE + ])), + type = "response", + s = expr(object$spec$args$penalty) + ) ) ) @@ -182,13 +180,12 @@ set_pred( pre = NULL, post = organize_glmnet_prob, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) + args = list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) ) ) @@ -201,11 +198,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) ) ) @@ -266,11 +262,10 @@ set_pred( pre = NULL, post = liblinear_preds, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = expr(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + newx = expr(as.matrix(new_data)) + ) ) ) @@ -283,12 +278,11 @@ set_pred( pre = NULL, post = liblinear_probs, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = expr(as.matrix(new_data)), - proba = TRUE - ) + args = list( + object = quote(object$fit), + newx = expr(as.matrix(new_data)), + proba = TRUE + ) ) ) @@ -303,7 +297,8 @@ set_pred( func = c(fun = "predict"), args = list( object = quote(object$fit), - newx = quote(new_data)) + newx = quote(new_data) + ) ) ) @@ -339,10 +334,9 @@ set_fit( data = c(formula = "formula", data = "x", weights = "weight_col"), protect = c("x", "formula", "weights"), func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), - defaults = - list( - family = "binomial" - ) + defaults = list( + family = "binomial" + ) ) ) @@ -367,11 +361,10 @@ set_pred( pre = NULL, post = format_spark_class, func = c(pkg = 
"sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) + args = list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) ) @@ -384,11 +377,10 @@ set_pred( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) + args = list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) ) @@ -440,11 +432,10 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "keras_predict_classes"), - args = - list( - object = quote(object), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object), + x = quote(as.matrix(new_data)) + ) ) ) @@ -461,11 +452,10 @@ set_pred( x }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) ) ) @@ -511,11 +501,10 @@ set_pred( unname(x) }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) @@ -533,11 +522,10 @@ set_pred( x }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) @@ -551,11 +539,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) @@ -569,17 +556,15 @@ set_pred( post = function(results, object) { res_2 <- tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$pred$conf_int$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$pred$conf_int$extras$level, - lower = FALSE - ), + lo = convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level + ), + hi = convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level, + lower = FALSE + ), ) res_1 <- res_2 res_1$lo <- 1 - res_2$hi @@ -596,11 +581,10 @@ set_pred( res }, func = c(pkg = "parsnip", fun = "stan_conf_int"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) + args = list( + object = expr(object$fit), + newdata = expr(new_data) + ) ) ) @@ -614,17 +598,15 @@ set_pred( post = function(results, object) { res_2 <- tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$pred$pred_int$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$pred$pred_int$extras$level, - lower = FALSE - ), + lo = convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level + ), + hi = convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level, + lower = FALSE + ), ) res_1 <- res_2 res_1$lo <- 1 - res_2$hi @@ -635,23 +617,22 @@ set_pred( colnames(res_2) <- c(lo_nms[2], hi_nms[2]) res <- bind_cols(res_1, res_2) - if (object$spec$method$pred$pred_int$extras$std_error) + if (object$spec$method$pred$pred_int$extras$std_error) { res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + } res }, func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) + args = list( + object = quote(object$fit), 
+ newdata = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) ) # ------------------------------------------------------------------------------ - set_model_engine("logistic_reg", "classification", "brulee") set_dependency("logistic_reg", "brulee", "brulee") @@ -706,12 +687,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" + ) ) ) @@ -724,12 +704,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) ) ) - diff --git a/R/mars.R b/R/mars.R index f2d125873..42c05b805 100644 --- a/R/mars.R +++ b/R/mars.R @@ -31,12 +31,16 @@ #' mars(mode = "regression", num_terms = 5) #' @export mars <- - function(mode = "unknown", engine = "earth", - num_terms = NULL, prod_degree = NULL, prune_method = NULL) { - + function( + mode = "unknown", + engine = "earth", + num_terms = NULL, + prod_degree = NULL, + prune_method = NULL + ) { args <- list( - num_terms = enquo(num_terms), - prod_degree = enquo(prod_degree), + num_terms = enquo(num_terms), + prod_degree = enquo(prod_degree), prune_method = enquo(prune_method) ) @@ -58,14 +62,18 @@ mars <- #' @rdname parsnip_update #' @export update.mars <- - function(object, - parameters = NULL, - num_terms = NULL, prod_degree = NULL, prune_method = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + num_terms = NULL, + prod_degree = NULL, + prune_method = NULL, + fresh = FALSE, + ... + ) { args <- list( - num_terms = enquo(num_terms), - prod_degree = enquo(prod_degree), + num_terms = enquo(num_terms), + prod_degree = enquo(prod_degree), prune_method = enquo(prune_method) ) @@ -106,12 +114,29 @@ translate.mars <- function(x, engine = x$engine, ...) { #' @export check_args.mars <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree") - check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms") - check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method") + check_number_whole( + args$prod_degree, + min = 1, + allow_null = TRUE, + call = call, + arg = "prod_degree" + ) + check_number_whole( + args$num_terms, + min = 1, + allow_null = TRUE, + call = call, + arg = "num_terms" + ) + check_string( + args$prune_method, + allow_empty = FALSE, + allow_null = TRUE, + call = call, + arg = "prune_method" + ) invisible(object) } @@ -147,8 +172,9 @@ multi_predict._earth <- function(object, new_data, type = NULL, num_terms = NULL, ...) 
{ load_libs(object, quiet = TRUE, attach = TRUE) - if (is.null(num_terms)) + if (is.null(num_terms)) { num_terms <- object$fit$selected.terms[-1] + } num_terms <- sort(num_terms) @@ -157,13 +183,16 @@ multi_predict._earth <- call_names <- names(object$fit$call) call_names <- call_names[!(call_names %in% c("", "x", "y"))] for (i in call_names) { - if (is_quosure(object$fit$call[[i]])) + if (is_quosure(object$fit$call[[i]])) { object$fit$call[[i]] <- eval_tidy(object$fit$call[[i]]) + } } msg <- - c("x" = "Please use {.code keepxy = TRUE} as an option to enable submodel - predictions with earth.") + c( + "x" = "Please use {.code keepxy = TRUE} as an option to enable submodel + predictions with earth." + ) if (any(names(object$fit$call) == "keepxy")) { if (!isTRUE(object$fit$call$keepxy)) { cli::cli_abort(msg) @@ -181,8 +210,14 @@ multi_predict._earth <- } res <- - map(num_terms, earth_by_terms, object = object, - new_data = new_data, type = type, ...) |> + map( + num_terms, + earth_by_terms, + object = object, + new_data = new_data, + type = type, + ... + ) |> purrr::list_rbind() res <- arrange(res, .row, num_terms) res <- split(res[, -1], res$.row) diff --git a/R/mars_data.R b/R/mars_data.R index 68079dae6..3e2e03020 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -1,4 +1,3 @@ - set_new_model("mars") set_model_mode("mars", "classification") @@ -92,12 +91,11 @@ set_pred( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -110,9 +108,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) @@ -128,12 +124,11 @@ set_pred( x }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -151,12 +146,11 @@ set_pred( x }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -169,8 +163,6 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) diff --git a/R/misc.R b/R/misc.R index fd6f48016..1d79ec16e 100644 --- a/R/misc.R +++ b/R/misc.R @@ -38,7 +38,7 @@ is_missing_arg <- function(x) { # these checks. 
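# Illustrative note (not part of this patch): when the user did not supply an
# engine, the early return below makes the filter a single TRUE, so no rows
# are dropped, e.g.
#   engine_filter_condition(engine = NULL, user_specified_engine = NULL, data = avail)
#   #> TRUE
# (here `avail` stands in for any registered-model data frame).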
engine_filter_condition <- function(engine, user_specified_engine, data) { # use !isTRUE so that result is TRUE if is.null(user_specified_engine) - if (!isTRUE(user_specified_engine) || is.null(engine)) { + if (!isTRUE(user_specified_engine) || is.null(engine)) { return(TRUE) } @@ -48,7 +48,7 @@ engine_filter_condition <- function(engine, user_specified_engine, data) { # analogous helper for modes to `engine_filter_condition()` mode_filter_condition <- function(mode, user_specified_mode, data) { # use !isTRUE so that result is TRUE if is.null(user_specified_mode) - if (!isTRUE(user_specified_mode) || is.null(mode)) { + if (!isTRUE(user_specified_mode) || is.null(mode)) { return(TRUE) } @@ -91,43 +91,47 @@ mode_filter_condition <- function(mode, user_specified_mode, data) { #' @export #' @keywords internal #' @rdname extension-check-helpers -spec_is_possible <- function(spec, - engine = spec$engine, - user_specified_engine = spec$user_specified_engine, - mode = spec$mode, - user_specified_mode = spec$user_specified_mode) { +spec_is_possible <- function( + spec, + engine = spec$engine, + user_specified_engine = spec$user_specified_engine, + mode = spec$mode, + user_specified_mode = spec$user_specified_mode +) { cls <- class(spec)[[1]] model_env <- rlang::env_get(get_model_env(), cls) model_env_matches <- model_env model_env_matches$model <- cls model_info_table_matches <- - vctrs::vec_slice(model_info_table, - model_info_table$model == cls) + vctrs::vec_slice(model_info_table, model_info_table$model == cls) if (isTRUE(user_specified_engine) && !is.null(engine)) { model_env_matches <- - vctrs::vec_slice(model_env_matches, - model_env_matches$engine == engine) + vctrs::vec_slice(model_env_matches, model_env_matches$engine == engine) model_info_table_matches <- - vctrs::vec_slice(model_info_table_matches, - model_info_table_matches$engine == engine) + vctrs::vec_slice( + model_info_table_matches, + model_info_table_matches$engine == engine + ) } if (isTRUE(user_specified_mode) && !is.null(mode)) { model_env_matches <- - vctrs::vec_slice(model_env_matches, - model_env_matches$mode == mode) + vctrs::vec_slice(model_env_matches, model_env_matches$mode == mode) model_info_table_matches <- - vctrs::vec_slice(model_info_table_matches, - model_info_table_matches$mode == mode) + vctrs::vec_slice( + model_info_table_matches, + model_info_table_matches$mode == mode + ) } - - if (vctrs::vec_size(model_env_matches) > 0 || - vctrs::vec_size(model_info_table_matches) > 0) { + if ( + vctrs::vec_size(model_env_matches) > 0 || + vctrs::vec_size(model_info_table_matches) > 0 + ) { return(TRUE) } @@ -138,11 +142,13 @@ spec_is_possible <- function(spec, #' @export #' @keywords internal #' @rdname extension-check-helpers -spec_is_loaded <- function(spec, - engine = spec$engine, - user_specified_engine = spec$user_specified_engine, - mode = spec$mode, - user_specified_mode = spec$user_specified_mode) { +spec_is_loaded <- function( + spec, + engine = spec$engine, + user_specified_engine = spec$user_specified_engine, + mode = spec$mode, + user_specified_mode = spec$user_specified_mode +) { cls <- class(spec)[[1]] avail <- get_from_env(cls) @@ -151,7 +157,11 @@ spec_is_loaded <- function(spec, return(FALSE) } - engine_condition <- engine_filter_condition(engine, user_specified_engine, avail) + engine_condition <- engine_filter_condition( + engine, + user_specified_engine, + avail + ) mode_condition <- mode_filter_condition(mode, user_specified_mode, avail) avail <- avail |> @@ -179,25 +189,40 @@ is_printable_spec <- 
function(x) { #' @export #' @keywords internal #' @rdname extension-check-helpers -prompt_missing_implementation <- function(spec, - engine = spec$engine, - user_specified_engine = spec$user_specified_engine, - mode = spec$mode, - user_specified_mode = spec$user_specified_mode, - prompt, ...) { +prompt_missing_implementation <- function( + spec, + engine = spec$engine, + user_specified_engine = spec$user_specified_engine, + mode = spec$mode, + user_specified_mode = spec$user_specified_mode, + prompt, + ... +) { cls <- class(spec)[[1]] avail <- get_from_env(cls) - engine_condition <- engine_filter_condition(engine, user_specified_engine, avail) + engine_condition <- engine_filter_condition( + engine, + user_specified_engine, + avail + ) mode_condition <- mode_filter_condition(mode, user_specified_mode, avail) if (!is.null(avail)) { avail <- vctrs::vec_slice(avail, mode_condition & engine_condition) } - engine_condition_all <- engine_filter_condition(engine, user_specified_engine, model_info_table) - mode_condition_all <- mode_filter_condition(mode, user_specified_mode, model_info_table) + engine_condition_all <- engine_filter_condition( + engine, + user_specified_engine, + model_info_table + ) + mode_condition_all <- mode_filter_condition( + mode, + user_specified_mode, + model_info_table + ) all <- vctrs::vec_slice( @@ -210,13 +235,15 @@ prompt_missing_implementation <- function(spec, all <- all[setdiff(names(all), "model")] - if (!isTRUE(user_specified_mode)) {mode <- ""} + if (!isTRUE(user_specified_mode)) { + mode <- "" + } msg <- c( "!" = "{.pkg parsnip} could not locate an implementation for `{cls}` {mode} \\ model specifications{if (isTRUE(user_specified_engine)) { paste0(' using the `', engine, '` engine')} else {''}}." - ) + ) if (nrow(avail) == 0 && nrow(all) > 0) { pkgs <- unique(all$pkg) @@ -224,8 +251,10 @@ prompt_missing_implementation <- function(spec, msg <- c( msg, - "i" = paste0("{cli::qty(pkgs)}The parsnip extension package{?s} {.pkg {pkgs}}", - " implemen{?ts/t} support for this specification."), + "i" = paste0( + "{cli::qty(pkgs)}The parsnip extension package{?s} {.pkg {pkgs}}", + " implemen{?ts/t} support for this specification." + ), "i" = "Please install (if needed) and load to continue." ) } @@ -324,7 +353,9 @@ update_dot_check <- function(...) { dots <- enquos(...) if (length(dots) > 0) { - cli::cli_abort("The extra argument{?s} {.arg {names(dots)}} will be ignored.") + cli::cli_abort( + "The extra argument{?s} {.arg {names(dots)}} will be ignored." + ) } invisible(NULL) @@ -335,15 +366,27 @@ update_dot_check <- function(...) { #' @export #' @keywords internal #' @rdname add_on_exports -new_model_spec <- function(cls, args, eng_args, mode, user_specified_mode = TRUE, - method, engine, user_specified_engine = TRUE) { +new_model_spec <- function( + cls, + args, + eng_args, + mode, + user_specified_mode = TRUE, + method, + engine, + user_specified_engine = TRUE +) { # determine if the model specification could feasibly match any entry # in the union of the parsnip model environment and model_info_table. # if not, trigger an error based on the (possibly inferred) model spec slots. 
out <- list( - args = args, eng_args = eng_args, - mode = mode, user_specified_mode = user_specified_mode, method = method, - engine = engine, user_specified_engine = user_specified_engine + args = args, + eng_args = eng_args, + mode = mode, + user_specified_mode = user_specified_mode, + method = method, + engine = engine, + user_specified_engine = user_specified_engine ) class(out) <- make_classes(cls) @@ -361,18 +404,28 @@ check_outcome <- function(y, spec) { return(invisible(NULL)) } - has_no_outcome <- if (is.atomic(y)) {is.null(y)} else {length(y) == 0} + has_no_outcome <- if (is.atomic(y)) { + is.null(y) + } else { + length(y) == 0 + } if (isTRUE(has_no_outcome)) { cli::cli_abort( - c("!" = "{.fun {class(spec)[1]}} was unable to find an outcome.", + c( + "!" = "{.fun {class(spec)[1]}} was unable to find an outcome.", "i" = "Ensure that you have specified an outcome column and that it \\ - hasn't been removed in pre-processing."), + hasn't been removed in pre-processing." + ), call = NULL ) } if (spec$mode == "regression") { - outcome_is_numeric <- if (is.atomic(y)) {is.numeric(y)} else {all(map_lgl(y, is.numeric))} + outcome_is_numeric <- if (is.atomic(y)) { + is.numeric(y) + } else { + all(map_lgl(y, is.numeric)) + } if (!outcome_is_numeric) { cli::cli_abort( "For a regression model, the outcome should be {.cls numeric}, not @@ -382,7 +435,11 @@ check_outcome <- function(y, spec) { } if (spec$mode == "classification") { - outcome_is_factor <- if (is.atomic(y)) {is.factor(y)} else {all(map_lgl(y, is.factor))} + outcome_is_factor <- if (is.atomic(y)) { + is.factor(y) + } else { + all(map_lgl(y, is.factor)) + } if (!outcome_is_factor) { cli::cli_abort( "For a classification model, the outcome should be a {.cls factor}, not @@ -390,7 +447,9 @@ check_outcome <- function(y, spec) { ) } - if (inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2) { + if ( + inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2 + ) { # warn rather than error since some engines handle this case by binning # all but the first level as the non-event, so this may be intended cli::cli_warn(c( @@ -425,16 +484,25 @@ check_final_param <- function(x, call = rlang::caller_env()) { return(invisible(x)) } if (!is.list(x) & !tibble::is_tibble(x)) { - cli::cli_abort("The parameter object should be a list or tibble.", call = call) + cli::cli_abort( + "The parameter object should be a list or tibble.", + call = call + ) } if (tibble::is_tibble(x) && nrow(x) > 1) { - cli::cli_abort("The parameter tibble should have a single row.", call = call) + cli::cli_abort( + "The parameter tibble should have a single row.", + call = call + ) } if (tibble::is_tibble(x)) { x <- as.list(x) } if (length(names) == 0 || any(names(x) == "")) { - cli::cli_abort("All values in {.arg parameters} should have a name.", call = call) + cli::cli_abort( + "All values in {.arg parameters} should have a name.", + call = call + ) } invisible(x) @@ -495,14 +563,16 @@ update_engine_parameters <- function(eng_args, fresh, ...) 
{ stan_conf_int <- function(object, newdata) { check_installs(list(method = list(libs = "rstanarm"))) if (utils::packageVersion("rstanarm") >= "2.21.1") { - fn <- rlang::call2("posterior_epred", + fn <- rlang::call2( + "posterior_epred", .ns = "rstanarm", object = expr(object), newdata = expr(newdata), seed = expr(sample.int(10^5, 1)) ) } else { - fn <- rlang::call2("posterior_linpred", + fn <- rlang::call2( + "posterior_linpred", .ns = "rstanarm", object = expr(object), newdata = expr(newdata), @@ -570,8 +640,7 @@ check_for_newdata <- function(..., call = rlang::caller_env()) { is_cran_check <- function() { if (identical(Sys.getenv("NOT_CRAN"), "true")) { FALSE - } - else { + } else { Sys.getenv("_R_CHECK_PACKAGE_NAME_", "") != "" } } @@ -599,8 +668,10 @@ is_cran_check <- function() { #' @export .get_prediction_column_names <- function(x, syms = FALSE) { if (!inherits(x, c("model_fit", "workflow"))) { - cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or - {.cls workflow}, not {.obj_type_friendly {x}}.") + cli::cli_abort( + "{.arg x} should be an object with class {.cls model_fit} or + {.cls workflow}, not {.obj_type_friendly {x}}." + ) } if (inherits(x, "workflow")) { @@ -619,10 +690,12 @@ is_cran_check <- function() { purrr::pluck("type") if (length(predict_types) == 0) { - cli::cli_abort("Prediction information could not be found for this + cli::cli_abort( + "Prediction information could not be found for this {.fn {model_type}} with engine {.val {model_engine}} and mode {.val {model_mode}}. Does a parsnip extension package need to - be loaded?") + be loaded?" + ) } res <- list(estimate = character(0), probabilities = character(0)) diff --git a/R/mlp.R b/R/mlp.R index cb406e99f..8c69124be 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -38,17 +38,23 @@ #' @export mlp <- - function(mode = "unknown", engine = "nnet", - hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, - activation = NULL, learn_rate = NULL) { - + function( + mode = "unknown", + engine = "nnet", + hidden_units = NULL, + penalty = NULL, + dropout = NULL, + epochs = NULL, + activation = NULL, + learn_rate = NULL + ) { args <- list( hidden_units = enquo(hidden_units), - penalty = enquo(penalty), - dropout = enquo(dropout), - epochs = enquo(epochs), - activation = enquo(activation), - learn_rate = enquo(learn_rate) + penalty = enquo(penalty), + dropout = enquo(dropout), + epochs = enquo(epochs), + activation = enquo(activation), + learn_rate = enquo(learn_rate) ) new_model_spec( @@ -69,19 +75,25 @@ mlp <- #' @rdname parsnip_update #' @export update.mlp <- - function(object, - parameters = NULL, - hidden_units = NULL, penalty = NULL, dropout = NULL, - epochs = NULL, activation = NULL, learn_rate = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + hidden_units = NULL, + penalty = NULL, + dropout = NULL, + epochs = NULL, + activation = NULL, + learn_rate = NULL, + fresh = FALSE, + ... + ) { args <- list( hidden_units = enquo(hidden_units), - penalty = enquo(penalty), - dropout = enquo(dropout), - epochs = enquo(epochs), - activation = enquo(activation), - learn_rate = enquo(learn_rate) + penalty = enquo(penalty), + dropout = enquo(dropout), + epochs = enquo(epochs), + activation = enquo(activation), + learn_rate = enquo(learn_rate) ) update_spec( @@ -104,7 +116,7 @@ translate.mlp <- function(x, engine = x$engine, ...) 
{ } if (engine == "nnet") { - if(isTRUE(is.null(quo_get_expr(x$args$hidden_units)))) { + if (isTRUE(is.null(quo_get_expr(x$args$hidden_units)))) { x$args$hidden_units <- 5 } } @@ -113,11 +125,13 @@ translate.mlp <- function(x, engine = x$engine, ...) { if (engine == "nnet") { if (x$mode == "classification") { - if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout")) + if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout")) { x$method$fit$args$linout <- FALSE + } } else { - if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout")) + if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout")) { x$method$fit$args$linout <- TRUE + } } } x @@ -127,14 +141,30 @@ translate.mlp <- function(x, engine = x$engine, ...) { #' @export check_args.mlp <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") - check_number_decimal(args$dropout, min = 0, max = 1, allow_null = TRUE, call = call, arg = "dropout") + check_number_decimal( + args$penalty, + min = 0, + allow_null = TRUE, + call = call, + arg = "penalty" + ) + check_number_decimal( + args$dropout, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "dropout" + ) - if (is.numeric(args$penalty) && is.numeric(args$dropout) && - args$dropout > 0 && args$penalty > 0) { + if ( + is.numeric(args$penalty) && + is.numeric(args$dropout) && + args$dropout > 0 && + args$penalty > 0 + ) { cli::cli_abort( "Both weight decay and dropout should not be specified.", call = call @@ -146,11 +176,13 @@ check_args.mlp <- function(object, call = rlang::caller_env()) { # keras wrapper for feed-forward nnet -class2ind <- function (x, drop2nd = FALSE, call = rlang::caller_env()) { +class2ind <- function(x, drop2nd = FALSE, call = rlang::caller_env()) { if (!is.factor(x)) { - cli::cli_abort(c("x" = "{.arg x} should be a {cls factor} not {.obj_type_friendly {x}.")) + cli::cli_abort(c( + "x" = "{.arg x} should be a {.cls factor}, not {.obj_type_friendly {x}}." + )) } - y <- model.matrix( ~ x - 1) + y <- model.matrix(~ x - 1) colnames(y) <- gsub("^x", "", colnames(y)) attributes(y)$assign <- NULL attributes(y)$contrasts <- NULL @@ -187,11 +219,17 @@ class2ind <- function (x, drop2nd = FALSE, call = rlang::caller_env()) { #' @keywords internal #' @export keras_mlp <- - function(x, y, - hidden_units = 5, penalty = 0, dropout = 0, epochs = 20, activation = "softmax", - seeds = sample.int(10^5, size = 3), - ...) { - + function( + x, + y, + hidden_units = 5, + penalty = 0, + dropout = 0, + epochs = 20, + activation = "softmax", + seeds = sample.int(10^5, size = 3), + ...
+ ) { allowed_keras_activation <- keras_activations() good_activation <- activation %in% allowed_keras_activation if (!all(good_activation)) { @@ -233,7 +271,9 @@ keras_mlp <- activation = activation, input_shape = ncol(x), kernel_regularizer = keras::regularizer_l2(penalty), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + kernel_initializer = keras::initializer_glorot_uniform( + seed = seeds[1] + ) ) } else { model |> @@ -241,7 +281,9 @@ keras_mlp <- units = hidden_units, activation = activation, input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + kernel_initializer = keras::initializer_glorot_uniform( + seed = seeds[1] + ) ) } @@ -251,7 +293,9 @@ keras_mlp <- units = hidden_units, activation = activation, input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + kernel_initializer = keras::initializer_glorot_uniform( + seed = seeds[1] + ) ) |> keras::layer_dropout(rate = dropout, seed = seeds[2]) } @@ -261,14 +305,18 @@ keras_mlp <- keras::layer_dense( units = ncol(y), activation = 'softmax', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + kernel_initializer = keras::initializer_glorot_uniform( + seed = seeds[3] + ) ) } else { model <- model |> keras::layer_dense( units = ncol(y), activation = 'linear', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + kernel_initializer = keras::initializer_glorot_uniform( + seed = seeds[3] + ) ) } @@ -303,10 +351,11 @@ keras_mlp <- nnet_softmax <- function(results, object) { - if (ncol(results) == 1) + if (ncol(results) == 1) { results <- cbind(1 - results, results) + } - results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) + results <- apply(results, 1, function(x) exp(x) / sum(exp(x))) results <- t(results) colnames(results) <- object$lvl results <- as_tibble(results) @@ -348,12 +397,25 @@ parse_keras_args <- function(...) { } mlp_num_weights <- function(p, hidden_units, classes) { - ((p + 1) * hidden_units) + ((hidden_units+1) * classes) + ((p + 1) * hidden_units) + ((hidden_units + 1) * classes) } allowed_keras_activation <- - c("elu", "exponential", "gelu", "hardsigmoid", "linear", "relu", "selu", - "sigmoid", "softmax", "softplus", "softsign", "swish", "tanh") + c( + "elu", + "exponential", + "gelu", + "hardsigmoid", + "linear", + "relu", + "selu", + "sigmoid", + "softmax", + "softplus", + "softsign", + "swish", + "tanh" + ) #' Activation functions for neural networks in keras #' @@ -382,22 +444,26 @@ multi_predict._torch_mlp <- function(object, new_data, type = NULL, epochs = NULL, ...) 
{ load_libs(object, quiet = TRUE, attach = TRUE) - if (is.null(epochs)) + if (is.null(epochs)) { epochs <- length(object$fit$models) + } epochs <- sort(epochs) if (is.null(type)) { - if (object$spec$mode == "classification") + if (object$spec$mode == "classification") { type <- "class" - else + } else { type <- "numeric" + } } res <- - purrr::map(epochs, - ~ predict(object, new_data, type, epochs = .x) |> - dplyr::mutate(epochs = .x)) |> + purrr::map( + epochs, + ~ predict(object, new_data, type, epochs = .x) |> + dplyr::mutate(epochs = .x) + ) |> purrr::map(\(x) x |> dplyr::mutate(.row = seq_len(nrow(new_data)))) |> purrr::list_rbind() |> dplyr::arrange(.row, epochs) @@ -408,14 +474,13 @@ multi_predict._torch_mlp <- reformat_torch_num <- function(results, object) { - if (isTRUE(ncol(results) > 1)) { nms <- colnames(results) results <- as_tibble(results, .name_repair = "minimal") if (length(nms) == 0 && length(object$preproc$y_var) == ncol(results)) { names(results) <- object$preproc$y_var } - } else { + } else { results <- unname(results[[1]]) } results diff --git a/R/mlp_data.R b/R/mlp_data.R index 4d3e5039d..98bda40fa 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -1,4 +1,3 @@ - set_new_model("mlp") set_model_mode("mlp", "classification") @@ -113,11 +112,10 @@ set_pred( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) ) ) @@ -130,13 +128,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) ) - ) set_pred( @@ -148,11 +144,10 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "keras_predict_classes"), - args = - list( - object = quote(object), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object), + x = quote(as.matrix(new_data)) + ) ) ) @@ -169,11 +164,10 @@ set_pred( x }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) ) ) @@ -186,11 +180,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) ) ) @@ -284,12 +277,11 @@ set_pred( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) ) @@ -302,13 +294,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) - ) set_pred( @@ -320,12 +310,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) ) ) @@ -338,12 +327,11 @@ set_pred( pre = NULL, post = nnet_softmax, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = 
quote(new_data), - type = "raw" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) ) @@ -356,11 +344,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) @@ -421,7 +408,11 @@ set_model_arg( eng = "brulee", parsnip = "activation", original = "activation", - func = list(pkg = "dials", fun = "activation", values = c('relu', 'elu', 'tanh')), + func = list( + pkg = "dials", + fun = "activation", + values = c('relu', 'elu', 'tanh') + ), has_submodel = FALSE ) @@ -483,12 +474,11 @@ set_pred( pre = NULL, post = reformat_torch_num, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) ) ) @@ -501,12 +491,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" + ) ) ) @@ -519,12 +508,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) ) ) @@ -584,7 +572,11 @@ set_model_arg( eng = "brulee_two_layer", parsnip = "activation", original = "activation", - func = list(pkg = "dials", fun = "activation", values = c('relu', 'elu', 'tanh')), + func = list( + pkg = "dials", + fun = "activation", + values = c('relu', 'elu', 'tanh') + ), has_submodel = FALSE ) @@ -646,12 +638,11 @@ set_pred( pre = NULL, post = reformat_torch_num, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) ) ) @@ -664,12 +655,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" + ) ) ) @@ -682,12 +672,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) ) ) - diff --git a/R/model_object_docs.R b/R/model_object_docs.R index 7fcc6a377..e4a3c9dcd 100644 --- a/R/model_object_docs.R +++ b/R/model_object_docs.R @@ -221,4 +221,3 @@ NULL #' #' nrow(fit_obj$fit$x) NULL - diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 1a8f0e8a1..4dcd7b149 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -46,11 +46,12 @@ #' multinom_reg() #' @export multinom_reg <- - function(mode = "classification", - engine = "nnet", - penalty = NULL, - mixture = NULL) { - + function( + mode = "classification", + engine = "nnet", + penalty = NULL, + mixture = NULL + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -77,11 +78,14 @@ translate.multinom_reg <- translate.linear_reg #' @rdname parsnip_update #' @export update.multinom_reg <- - function(object, - parameters = NULL, - 
penalty = NULL, mixture = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, + ... + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -101,11 +105,23 @@ update.multinom_reg <- #' @export check_args.multinom_reg <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal( + args$mixture, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "mixture" + ) + check_number_decimal( + args$penalty, + min = 0, + allow_null = TRUE, + call = call, + arg = "penalty" + ) invisible(object) } diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 9f805f16d..9b5982f26 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -58,13 +58,15 @@ set_pred( pre = NULL, post = organize_multnet_class, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data[, rownames(object$fit$beta[[1]]), drop = FALSE])), - type = "class", - s = quote(object$spec$args$penalty) - ) + args = list( + object = quote(object$fit), + newx = quote(as.matrix(new_data[, + rownames(object$fit$beta[[1]]), + drop = FALSE + ])), + type = "class", + s = quote(object$spec$args$penalty) + ) ) ) @@ -77,13 +79,15 @@ set_pred( pre = NULL, post = organize_multnet_prob, func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data[, rownames(object$fit$beta[[1]]), drop = FALSE])), - type = "response", - s = expr(object$spec$args$penalty) - ) + args = list( + object = expr(object$fit), + newx = expr(as.matrix(new_data[, + rownames(object$fit$beta[[1]]), + drop = FALSE + ])), + type = "response", + s = expr(object$spec$args$penalty) + ) ) ) @@ -96,11 +100,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) ) ) @@ -161,11 +164,10 @@ set_pred( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) + args = list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) ) @@ -179,11 +181,10 @@ set_pred( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) + args = list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) ) @@ -236,9 +237,7 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "keras_predict_classes"), - args = - list(object = quote(object), - x = quote(as.matrix(new_data))) + args = list(object = quote(object), x = quote(as.matrix(new_data))) ) ) @@ -255,9 +254,7 @@ set_pred( x }, func = c(fun = "predict"), - args = - list(object = quote(object$fit), - x = quote(as.matrix(new_data))) + args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) ) ) @@ -309,12 +306,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) ) ) @@ 
-327,12 +323,11 @@ set_pred( pre = NULL, post = organize_nnet_prob, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) ) ) @@ -345,18 +340,16 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) # ------------------------------------------------------------------------------ - set_model_engine("multinom_reg", "classification", "brulee") set_dependency("multinom_reg", "brulee", "brulee") @@ -412,12 +405,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" + ) ) ) @@ -430,11 +422,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) ) ) diff --git a/R/naive_Bayes.R b/R/naive_Bayes.R index ffe91ae66..7794e24b7 100644 --- a/R/naive_Bayes.R +++ b/R/naive_Bayes.R @@ -28,7 +28,12 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("naive_Bayes")} #' @export naive_Bayes <- - function(mode = "classification", smoothness = NULL, Laplace = NULL, engine = "klaR") { + function( + mode = "classification", + smoothness = NULL, + Laplace = NULL, + engine = "klaR" + ) { args <- list( smoothness = rlang::enquo(smoothness), @@ -54,14 +59,11 @@ naive_Bayes <- #' @inheritParams naive_Bayes #' @export update.naive_Bayes <- - function(object, - smoothness = NULL, Laplace = NULL, - fresh = FALSE, ...) { - + function(object, smoothness = NULL, Laplace = NULL, fresh = FALSE, ...) { args <- list( smoothness = rlang::enquo(smoothness), - Laplace = rlang::enquo(Laplace) + Laplace = rlang::enquo(Laplace) ) update_spec( diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index e09dba17b..5f5354a1c 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -39,15 +39,17 @@ #' nearest_neighbor(neighbors = 11) #' #' @export -nearest_neighbor <- function(mode = "unknown", - engine = "kknn", - neighbors = NULL, - weight_func = NULL, - dist_power = NULL) { +nearest_neighbor <- function( + mode = "unknown", + engine = "kknn", + neighbors = NULL, + weight_func = NULL, + dist_power = NULL +) { args <- list( - neighbors = enquo(neighbors), + neighbors = enquo(neighbors), weight_func = enquo(weight_func), - dist_power = enquo(dist_power) + dist_power = enquo(dist_power) ) new_model_spec( @@ -67,17 +69,19 @@ nearest_neighbor <- function(mode = "unknown", #' @method update nearest_neighbor #' @export #' @rdname parsnip_update -update.nearest_neighbor <- function(object, - parameters = NULL, - neighbors = NULL, - weight_func = NULL, - dist_power = NULL, - fresh = FALSE, ...) { - +update.nearest_neighbor <- function( + object, + parameters = NULL, + neighbors = NULL, + weight_func = NULL, + dist_power = NULL, + fresh = FALSE, + ... 
+) { args <- list( - neighbors = enquo(neighbors), + neighbors = enquo(neighbors), weight_func = enquo(weight_func), - dist_power = enquo(dist_power) + dist_power = enquo(dist_power) ) update_spec( @@ -94,12 +98,22 @@ update.nearest_neighbor <- function(object, #' @export check_args.nearest_neighbor <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_whole(args$neighbors, min = 0, allow_null = TRUE, call = call, arg = "neighbors") - check_string(args$weight_func, allow_null = TRUE, call = call, arg = "weight_func") - + check_number_whole( + args$neighbors, + min = 0, + allow_null = TRUE, + call = call, + arg = "neighbors" + ) + check_string( + args$weight_func, + allow_null = TRUE, + call = call, + arg = "weight_func" + ) + invisible(object) } @@ -144,20 +158,28 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) { #' @export multi_predict._train.kknn <- function(object, new_data, type = NULL, neighbors = NULL, ...) { - if (is.null(neighbors)) + if (is.null(neighbors)) { neighbors <- rlang::eval_tidy(object$fit$call$ks) + } neighbors <- sort(neighbors) if (is.null(type)) { - if (object$spec$mode == "classification") + if (object$spec$mode == "classification") { type <- "class" - else + } else { type <- "numeric" + } } res <- - purrr::map(neighbors, knn_by_k, object = object, - new_data = new_data, type = type, ...) |> + purrr::map( + neighbors, + knn_by_k, + object = object, + new_data = new_data, + type = type, + ... + ) |> purrr::list_rbind() res <- dplyr::arrange(res, .row, neighbors) res <- split(res[, -1], res$.row) diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index a24051f61..264465785 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -1,4 +1,3 @@ - set_new_model("nearest_neighbor") set_model_mode("nearest_neighbor", "classification") @@ -31,7 +30,7 @@ set_model_arg( eng = "kknn", parsnip = "dist_power", original = "distance", - func = list(pkg = "dials", fun = "dist_power", range = c(1/10, 2)), + func = list(pkg = "dials", fun = "dist_power", range = c(1 / 10, 2)), has_submodel = FALSE ) @@ -88,26 +87,27 @@ set_pred( eng = "kknn", mode = "regression", type = "numeric", - value = list( + value = list( # seems unnecessary here as the predict_numeric catches it based on the # model mode pre = function(x, object) { if (object$fit$response != "continuous") { cli::cli_abort( - c("`kknn` model does not appear to use numeric predictions.", - "i" = "Was the model fit with a continuous response variable?") + c( + "`kknn` model does not appear to use numeric predictions.", + "i" = "Was the model fit with a continuous response variable?" 
+ ) ) } x }, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) ) @@ -120,11 +120,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) @@ -137,20 +136,21 @@ set_pred( pre = function(x, object) { if (!(object$fit$response %in% c("ordinal", "nominal"))) { cli::cli_abort( - c("`kknn` model does not appear to use class predictions.", - "i" = "Was the model fit with a factor response variable?") + c( + "`kknn` model does not appear to use class predictions.", + "i" = "Was the model fit with a factor response variable?" + ) ) } x }, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) ) @@ -163,20 +163,21 @@ set_pred( pre = function(x, object) { if (!(object$fit$response %in% c("ordinal", "nominal"))) { cli::cli_abort( - c("`kknn` model does not appear to use class predictions.", - "i" = "Was the model fit with a factor response variable?") + c( + "`kknn` model does not appear to use class predictions.", + "i" = "Was the model fit with a factor response variable?" + ) ) } x }, post = function(result, object) as_tibble(result), func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) ) ) @@ -189,10 +190,9 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) diff --git a/R/nullmodel.R b/R/nullmodel.R index 36deaaf72..34a431bb7 100644 --- a/R/nullmodel.R +++ b/R/nullmodel.R @@ -43,22 +43,20 @@ #' predict(useless, matrix(NA, nrow = 5)) #' #' @export -nullmodel <- function (x, ...) UseMethod("nullmodel") +nullmodel <- function(x, ...) UseMethod("nullmodel") #' @export #' @rdname nullmodel nullmodel.default <- function(x = NULL, y, ...) { - - - if(is.factor(y)) { + if (is.factor(y)) { lvls <- levels(y) tab <- table(y) value <- names(tab)[which.max(tab)] - pct <- tab/sum(tab) + pct <- tab / sum(tab) } else { lvls <- NULL pct <- NULL - if(is.null(dim(y))) { + if (is.null(dim(y))) { value <- mean(y, na.rm = TRUE) } else { value <- colMeans(y, na.rm = TRUE) @@ -66,44 +64,48 @@ nullmodel.default <- function(x = NULL, y, ...) { } structure( - list(call = match.call(), - value = value, - levels = lvls, - pct = pct, - n = length(y[[1]])), - class = "nullmodel") + list( + call = match.call(), + value = value, + levels = lvls, + pct = pct, + n = length(y[[1]]) + ), + class = "nullmodel" + ) } #' @export #' @rdname nullmodel print.nullmodel <- function(x, ...) 
{ - cat("Null", - ifelse(is.null(x$levels), "Classification", "Regression"), - "Model\n") + cat( + "Null", + ifelse(is.null(x$levels), "Classification", "Regression"), + "Model\n" + ) x$call if (length(x$value) == 1) { - cat("Predicted Value:", - ifelse(is.null(x$levels), format(x$value), x$value), - "\n") + cat( + "Predicted Value:", + ifelse(is.null(x$levels), format(x$value), x$value), + "\n" + ) } else { - cat("Predicted Value:\n", - names(x$value), "\n", - x$value, - "\n") + cat("Predicted Value:\n", names(x$value), "\n", x$value, "\n") } } #' @export #' @rdname nullmodel -predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) { - if(is.null(type)) { - type <- if(is.null(object$levels)) "raw" else "class" +predict.nullmodel <- function(object, new_data = NULL, type = NULL, ...) { + if (is.null(type)) { + type <- if (is.null(object$levels)) "raw" else "class" } - n <- if(is.null(new_data)) object$n else nrow(new_data) - if(!is.null(object$levels)) { - if(type == "prob") { + n <- if (is.null(new_data)) object$n else nrow(new_data) + if (!is.null(object$levels)) { + if (type == "prob") { out <- matrix(rep(object$pct, n), nrow = n, byrow = TRUE) colnames(out) <- object$levels out <- as.data.frame(out) @@ -112,12 +114,18 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) { } } else { if (type %in% c("prob", "class")) { - cli::cli_abort("Only numeric predicitons are applicable to regression models.") + cli::cli_abort( + "Only numeric predictions are applicable to regression models." + ) } if (length(object$value) == 1) { out <- rep(object$value, n) } else { - out <- matrix(rep(object$value, n), ncol = length(object$value), byrow = TRUE) + out <- matrix( + rep(object$value, n), + ncol = length(object$value), + byrow = TRUE + ) colnames(out) <- names(object$value) out <- as_tibble(out) } @@ -165,8 +173,7 @@ null_model <- engine = engine, user_specified_engine = !missing(engine) ) -} - + } #' Tidy method for null models @@ -185,4 +192,3 @@ null_model <- tidy.nullmodel <- function(x, ...)
{ tibble::tibble(value = x$value) } - diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index 42aa52c7a..4f6f0112c 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -66,12 +66,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) ) ) @@ -84,12 +83,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) ) ) @@ -102,12 +100,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" + ) ) ) @@ -122,12 +119,11 @@ set_pred( as_tibble(x) }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) ) ) @@ -140,12 +136,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "raw" - ) + args = list( + object = quote(object$fit), + new_data = quote(new_data), + type = "raw" + ) ) ) - diff --git a/R/parsnip-package.R b/R/parsnip-package.R index 01f1f42c1..e0a429395 100644 --- a/R/parsnip-package.R +++ b/R/parsnip-package.R @@ -34,15 +34,60 @@ NULL utils::globalVariables( c( - '.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', - 'lab', 'original', 'predicted_label', 'prediction', 'value', 'type', - "neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty", - "max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees", - "sub_neighbors", ".pred_class", "x", "y", "predictor_indicators", - "compute_intercept", "remove_intercept", "estimate", "term", - "call_info", "component", "component_id", "func", "tunable", "label", - "pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts", - "protect", "weight_time", ".prob_cens", ".weight_cens", "s" + '.', + '.label', + '.pred', + '.row', + 'data', + 'engine', + 'engine2', + 'group', + 'lab', + 'original', + 'predicted_label', + 'prediction', + 'value', + 'type', + "neighbors", + ".submodels", + "has_submodel", + "max_neighbor", + "max_penalty", + "max_terms", + "max_tree", + "model", + "name", + "num_terms", + "penalty", + "trees", + "sub_neighbors", + ".pred_class", + "x", + "y", + "predictor_indicators", + "compute_intercept", + "remove_intercept", + "estimate", + "term", + "call_info", + "component", + "component_id", + "func", + "tunable", + "label", + "pkg", + ".order", + "item", + "tunable", + "has_ext", + "id", + "weights", + "has_wts", + "protect", + "weight_time", + ".prob_cens", + ".weight_cens", + "s" ) ) @@ -55,4 +100,3 @@ release_bullets <- function() { } # nocov end - diff --git a/R/partykit.R b/R/partykit.R index a1ca20da5..0bf3c4be9 100644 --- a/R/partykit.R +++ b/R/partykit.R @@ -34,15 +34,17 @@ #' } #' @export ctree_train <- - function(formula, - data, - weights = NULL, - minsplit = 20L, - maxdepth = Inf, - teststat = "quadratic", - testtype = "Bonferroni", - mincriterion = 0.95, - ...) 
{ + function( + formula, + data, + weights = NULL, + minsplit = 20L, + maxdepth = Inf, + teststat = "quadratic", + testtype = "Bonferroni", + mincriterion = 0.95, + ... + ) { rlang::check_installed("partykit") opts <- rlang::list2(...) @@ -76,9 +78,11 @@ ctree_train <- !!!opts ) if (!is.null(weights)) { - if (!is.vector(weights) || + if ( + !is.vector(weights) || !is.integer(weights) || - length(weights) != nrow(data)) { + length(weights) != nrow(data) + ) { cli::cli_abort( "{.arg weights} should be an integer vector with size the same as the number of rows of {.arg data}." @@ -103,22 +107,24 @@ ctree_train <- #' @rdname ctree_train #' @export cforest_train <- - function(formula, - data, - weights = NULL, - minsplit = 20L, - maxdepth = Inf, - teststat = "quadratic", - testtype = "Univariate", - mincriterion = 0, - mtry = ceiling(sqrt(ncol(data) - 1)), - ntree = 500L, - ...) { + function( + formula, + data, + weights = NULL, + minsplit = 20L, + maxdepth = Inf, + teststat = "quadratic", + testtype = "Univariate", + mincriterion = 0, + mtry = ceiling(sqrt(ncol(data) - 1)), + ntree = 500L, + ... + ) { rlang::check_installed("partykit") force(mtry) opts <- rlang::list2(...) - mtry <- max_mtry_formula(mtry, formula, data) + mtry <- max_mtry_formula(mtry, formula, data) minsplit <- min(minsplit, nrow(data)) if (any(names(opts) == "control")) { @@ -156,9 +162,11 @@ cforest_train <- ) if (!is.null(weights)) { - if (!is.vector(weights) || + if ( + !is.vector(weights) || !is.numeric(weights) || - length(weights) != nrow(data)) { + length(weights) != nrow(data) + ) { cli::cli_abort( "{.arg weights} should be a numeric vector with size the same as the number of rows of {.arg data}." diff --git a/R/pls.R b/R/pls.R index 3a2bc13d7..dfca73172 100644 --- a/R/pls.R +++ b/R/pls.R @@ -24,11 +24,15 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("pls")} #' @export pls <- - function(mode = "unknown", predictor_prop = NULL, num_comp = NULL, engine = "mixOmics") { - + function( + mode = "unknown", + predictor_prop = NULL, + num_comp = NULL, + engine = "mixOmics" + ) { args <- list( predictor_prop = enquo(predictor_prop), - num_comp = enquo(num_comp) + num_comp = enquo(num_comp) ) new_model_spec( @@ -64,14 +68,17 @@ pls <- #' @rdname parsnip_update #' @export update.pls <- - function(object, - parameters = NULL, - predictor_prop = NULL, num_comp = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + predictor_prop = NULL, + num_comp = NULL, + fresh = FALSE, + ... 
+ ) { args <- list( - predictor_prop = enquo(predictor_prop), - num_comp = enquo(num_comp) + predictor_prop = enquo(predictor_prop), + num_comp = enquo(num_comp) ) update_spec( @@ -88,10 +95,15 @@ update.pls <- #' @export check_args.pls <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_whole(args$num_comp, min = 0, allow_null = TRUE, call = call, arg = "num_comp") + check_number_whole( + args$num_comp, + min = 0, + allow_null = TRUE, + call = call, + arg = "num_comp" + ) invisible(object) } diff --git a/R/poisson_reg.R b/R/poisson_reg.R index e3201aca1..74db19b67 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -32,11 +32,12 @@ #' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("poisson_reg")} #' @export poisson_reg <- - function(mode = "regression", - penalty = NULL, - mixture = NULL, - engine = "glm") { - + function( + mode = "regression", + penalty = NULL, + mixture = NULL, + engine = "glm" + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -60,11 +61,14 @@ poisson_reg <- #' @rdname parsnip_update #' @export update.poisson_reg <- - function(object, - parameters = NULL, - penalty = NULL, mixture = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, + ... + ) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -102,11 +106,23 @@ translate.poisson_reg <- function(x, engine = x$engine, ...) { #' @export check_args.poisson_reg <- function(object, call = rlang::caller_env()) { - args <- lapply(object$args, rlang::eval_tidy) - check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal( + args$mixture, + min = 0, + max = 1, + allow_null = TRUE, + call = call, + arg = "mixture" + ) + check_number_decimal( + args$penalty, + min = 0, + allow_null = TRUE, + call = call, + arg = "penalty" + ) invisible(object) } diff --git a/R/predict.R b/R/predict.R index da62a7b7b..ef2a49635 100644 --- a/R/predict.R +++ b/R/predict.R @@ -145,7 +145,13 @@ #' @method predict model_fit #' @export predict.model_fit #' @export -predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) { +predict.model_fit <- function( + object, + new_data, + type = NULL, + opts = list(), + ... +) { if (inherits(object$fit, "try-error")) { cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) @@ -156,7 +162,9 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) type <- check_pred_type(object, type) if (type != "raw" && length(opts) > 0) { - cli::cli_warn("{.arg opts} is only used with `type = 'raw'` and was ignored.") + cli::cli_warn( + "{.arg opts} is only used with `type = 'raw'` and was ignored." + ) } check_pred_type_dots(object, type, ...) @@ -164,28 +172,32 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) 
res <- switch( type, - numeric = predict_numeric(object = object, new_data = new_data, ...), - class = predict_class(object = object, new_data = new_data, ...), - prob = predict_classprob(object = object, new_data = new_data, ...), - conf_int = predict_confint(object = object, new_data = new_data, ...), - pred_int = predict_predint(object = object, new_data = new_data, ...), - quantile = predict_quantile(object = object, new_data = new_data, ...), - time = predict_time(object = object, new_data = new_data, ...), - survival = predict_survival(object = object, new_data = new_data, ...), - linear_pred = predict_linear_pred(object = object, new_data = new_data, ...), - hazard = predict_hazard(object = object, new_data = new_data, ...), - raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), + numeric = predict_numeric(object = object, new_data = new_data, ...), + class = predict_class(object = object, new_data = new_data, ...), + prob = predict_classprob(object = object, new_data = new_data, ...), + conf_int = predict_confint(object = object, new_data = new_data, ...), + pred_int = predict_predint(object = object, new_data = new_data, ...), + quantile = predict_quantile(object = object, new_data = new_data, ...), + time = predict_time(object = object, new_data = new_data, ...), + survival = predict_survival(object = object, new_data = new_data, ...), + linear_pred = predict_linear_pred( + object = object, + new_data = new_data, + ... + ), + hazard = predict_hazard(object = object, new_data = new_data, ...), + raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), cli::cli_abort("Unknown prediction {.arg type} '{type}'.") ) if (!inherits(res, "tbl_spark")) { res <- switch( type, - numeric = format_num(res), - class = format_class(res), - prob = format_classprobs(res), - time = format_time(res), - survival = format_survival(res), - hazard = format_hazard(res), + numeric = format_num(res), + class = format_class(res), + prob = format_classprobs(res), + time = format_time(res), + survival = format_survival(res), + hazard = format_hazard(res), linear_pred = format_linear_pred(res), res ) @@ -209,11 +221,12 @@ check_pred_type <- function(object, type, ..., call = rlang::caller_env()) { ) } - if (!(type %in% pred_types)) + if (!(type %in% pred_types)) { cli::cli_abort( "{.arg type} should be one of {.or {.arg {pred_types}}}.", call = call ) + } switch( type, @@ -326,7 +339,7 @@ format_survival <- function(x) { #' @rdname format-internals #' @export format_linear_pred <- function(x) { - if (inherits(x, "tbl_spark")){ + if (inherits(x, "tbl_spark")) { return(x) } ensure_parsnip_format(x, ".pred_linear_pred") @@ -352,8 +365,10 @@ ensure_parsnip_format <- function(x, col_name, overwrite = TRUE) { } } } else { - x <- tibble::new_tibble(vctrs::df_list(unname(x), .name_repair = "minimal"), - nrow = length(x)) + x <- tibble::new_tibble( + vctrs::df_list(unname(x), .name_repair = "minimal"), + nrow = length(x) + ) names(x) <- col_name x } @@ -361,16 +376,22 @@ ensure_parsnip_format <- function(x, col_name, overwrite = TRUE) { } make_pred_call <- function(x) { - if ("pkg" %in% names(x$func)) + if ("pkg" %in% names(x$func)) { cl <- - call2(x$func["fun"],!!!x$args, .ns = x$func["pkg"]) - else - cl <- call2(x$func["fun"],!!!x$args) + call2(x$func["fun"], !!!x$args, .ns = x$func["pkg"]) + } else { + cl <- call2(x$func["fun"], !!!x$args) + } cl } -check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) { +check_pred_type_dots <- function( + object, + 
type, + ..., + call = rlang::caller_env() +) { the_dots <- list(...) nms <- names(the_dots) @@ -380,8 +401,15 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) # ---------------------------------------------------------------------------- - other_args <- c("interval", "level", "std_error", "quantile_levels", - "time", "eval_time", "increasing") + other_args <- c( + "interval", + "level", + "std_error", + "quantile_levels", + "time", + "eval_time", + "increasing" + ) eval_time_types <- c("survival", "hazard") @@ -390,9 +418,9 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) bad_args <- names(the_dots)[!is_pred_arg] bad_args <- paste0("`", bad_args, "`", collapse = ", ") cli::cli_abort( - "The ellipses are not used to pass args to the model function's + "The ellipses are not used to pass args to the model function's predict function. These arguments cannot be used: {.val bad_args}", - call = call + call = call ) } @@ -402,10 +430,8 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) cli::cli_abort( "{.arg eval_time} should only be passed to {.fn predict} when \\ {.arg type} is one of {.or {.val {eval_time_types}}}.", - call = call - ) - - + call = call + ) } if (any(nms == "time") & !type %in% c("survival", "hazard")) { cli::cli_abort( @@ -415,24 +441,27 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) ) } # when eval_time should be passed - if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) { - cli::cli_abort( - "When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric + if ( + !any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard") + ) { + cli::cli_abort( + "When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric vector {.arg eval_time} should also be given.", - call = call - ) + call = call + ) } # `increasing` only applies to linear_pred for censored regression - if (any(nms == "increasing") & + if ( + any(nms == "increasing") & !(type == "linear_pred" & - object$spec$mode == "censored regression")) { + object$spec$mode == "censored regression") + ) { cli::cli_abort( "{.arg increasing} only applies to predictions of type 'linear_pred' for the mode censored regression.", call = call ) - } invisible(TRUE) @@ -491,4 +520,3 @@ prepare_data <- function(object, new_data) { new_data ) } - diff --git a/R/predict_class.R b/R/predict_class.R index 98d1adff0..ce3aafc56 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -10,15 +10,19 @@ #' @export predict_class.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "classification") { - cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.", - call = rlang::call2("predict")) + cli::cli_abort( + "{.fun predict.model_fit} is for predicting factor outcomes.", + call = rlang::call2("predict") + ) } check_spec_pred_type(object, "class") if (inherits(object$fit, "try-error")) { - cli::cli_warn("Model fit failed; cannot make predictions.", - call = rlang::call2("predict")) + cli::cli_warn( + "Model fit failed; cannot make predictions.", + call = rlang::call2("predict") + ) return(NULL) } @@ -41,16 +45,22 @@ predict_class.model_fit <- function(object, new_data, ...) 
{ # coerce levels to those in `object` if (is.vector(res) || is.factor(res)) { - res <- factor(as.character(res), levels = object$lvl, ordered = object$ordered) + res <- factor( + as.character(res), + levels = object$lvl, + ordered = object$ordered + ) } else { if (!inherits(res, "tbl_spark")) { # Now case where a parsnip model generated `res` if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) { res <- res[[1]] } else { - res$values <- factor(as.character(res$values), - levels = object$lvl, - ordered = object$ordered) + res$values <- factor( + as.character(res$values), + levels = object$lvl, + ordered = object$ordered + ) } } } diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 7642ae3d3..19b1ce0d3 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -6,8 +6,10 @@ #' @export predict_classprob.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "classification") { - cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.", - call = rlang::call2("predict")) + cli::cli_abort( + "{.fun predict.model_fit()} is for predicting factor outcomes.", + call = rlang::call2("predict") + ) } check_spec_pred_type(object, "prob", call = caller_env()) @@ -37,8 +39,10 @@ predict_classprob.model_fit <- function(object, new_data, ...) { # check and sort names if (!is.data.frame(res) & !inherits(res, "tbl_spark")) { - cli::cli_abort("The was a problem with the probability predictions.", - call = rlang::call2("predict")) + cli::cli_abort( + "There was a problem with the probability predictions.", + call = rlang::call2("predict") + ) } if (!is_tibble(res) & !inherits(res, "tbl_spark")) { diff --git a/R/predict_hazard.R b/R/predict_hazard.R index a92fb454d..e49867106 100644 --- a/R/predict_hazard.R +++ b/R/predict_hazard.R @@ -4,11 +4,13 @@ #' @method predict_hazard model_fit #' @export predict_hazard.model_fit #' @export -predict_hazard.model_fit <- function(object, - new_data, - eval_time, - time = deprecated(), - ...) { +predict_hazard.model_fit <- function( + object, + new_data, + eval_time, + time = deprecated(), + ... +) { if (lifecycle::is_present(time)) { lifecycle::deprecate_warn( "1.0.4.9005", @@ -29,8 +31,9 @@ predict_hazard.model_fit <- function(object, new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$hazard$pre)) + if (!is.null(object$spec$method$pred$hazard$pre)) { new_data <- object$spec$method$pred$hazard$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$hazard) @@ -38,7 +41,7 @@ predict_hazard.model_fit <- function(object, res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$pred$hazard$post)) { + if (!is.null(object$spec$method$pred$hazard$post)) { res <- object$spec$method$pred$hazard$post(res, object) } @@ -49,5 +52,6 @@ predict_hazard.model_fit <- function(object, # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_hazard <- function (object, ...) +predict_hazard <- function(object, ...) { UseMethod("predict_hazard") +} diff --git a/R/predict_interval.R b/R/predict_interval.R index 136735a94..0552dc2a5 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -8,8 +8,13 @@ #' @method predict_confint model_fit #' @export predict_confint.model_fit #' @export -predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...)
{ - +predict_confint.model_fit <- function( + object, + new_data, + level = 0.95, + std_error = FALSE, + ... +) { check_spec_pred_type(object, "conf_int") if (inherits(object$fit, "try-error")) { @@ -20,8 +25,9 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$conf_int$pre)) + if (!is.null(object$spec$method$pred$conf_int$pre)) { new_data <- object$spec$method$pred$conf_int$pre(new_data, object) + } # Pass some extra arguments to be used in post-processor object$spec$method$pred$conf_int$extras <- @@ -44,8 +50,9 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error #' @keywords internal #' @rdname other_predict #' @inheritParams predict.model_fit -predict_confint <- function(object, ...) +predict_confint <- function(object, ...) { UseMethod("predict_confint") +} # ------------------------------------------------------------------------------ @@ -53,8 +60,9 @@ predict_confint <- function(object, ...) #' @keywords internal #' @rdname other_predict #' @inheritParams predict.model_fit -predict_predint <- function(object, ...) +predict_predint <- function(object, ...) { UseMethod("predict_predint") +} #' @keywords internal @@ -63,8 +71,13 @@ predict_predint <- function(object, ...) #' @method predict_predint model_fit #' @export predict_predint.model_fit #' @export -predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - +predict_predint.model_fit <- function( + object, + new_data, + level = 0.95, + std_error = FALSE, + ... +) { check_spec_pred_type(object, "pred_int") if (inherits(object$fit, "try-error")) { @@ -75,8 +88,9 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$pred_int$pre)) + if (!is.null(object$spec$method$pred$pred_int$pre)) { new_data <- object$spec$method$pred$pred_int$pre(new_data, object) + } # create prediction call # Pass some extra arguments to be used in post-processor @@ -100,6 +114,6 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error #' @keywords internal #' @rdname other_predict #' @inheritParams predict.model_fit -predict_predint <- function(object, ...) +predict_predint <- function(object, ...) { UseMethod("predict_predint") - +} diff --git a/R/predict_linear_pred.R b/R/predict_linear_pred.R index 927164578..08e5ae585 100644 --- a/R/predict_linear_pred.R +++ b/R/predict_linear_pred.R @@ -5,7 +5,6 @@ #' @export predict_linear_pred.model_fit #' @export predict_linear_pred.model_fit <- function(object, new_data, ...) { - check_spec_pred_type(object, "linear_pred") if (inherits(object$fit, "try-error")) { @@ -16,8 +15,9 @@ predict_linear_pred.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$linear_pred$pre)) + if (!is.null(object$spec$method$pred$linear_pred$pre)) { new_data <- object$spec$method$pred$linear_pred$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$linear_pred) @@ -41,5 +41,6 @@ predict_linear_pred.model_fit <- function(object, new_data, ...) { #' @keywords internal #' @rdname other_predict #' @inheritParams predict_linear_pred.model_fit -predict_linear_pred <- function(object, ...) +predict_linear_pred <- function(object, ...) 
{ UseMethod("predict_linear_pred") +} diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 54d6461ff..74d1c8210 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -9,13 +9,15 @@ #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export -predict_quantile.model_fit <- function(object, - new_data, - quantile_levels = NULL, - quantile = deprecated(), - interval = "none", - level = 0.95, - ...) { +predict_quantile.model_fit <- function( + object, + new_data, + quantile_levels = NULL, + quantile = deprecated(), + interval = "none", + level = 0.95, + ... +) { check_dots_empty() check_spec_pred_type(object, "quantile") @@ -28,7 +30,6 @@ predict_quantile.model_fit <- function(object, quantile_levels <- quantile } - if (inherits(object$fit, "try-error")) { cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) @@ -36,13 +37,15 @@ predict_quantile.model_fit <- function(object, if (object$spec$mode == "quantile regression") { if (!is.null(quantile_levels)) { - cli::cli_abort("When the mode is {.val quantile regression}, + cli::cli_abort( + "When the mode is {.val quantile regression}, {.arg quantile_levels} are specified by {.fn set_mode}.", - call = rlang::call2("predict")) + call = rlang::call2("predict") + ) } } else { if (is.null(quantile_levels)) { - quantile_levels <- (1:9)/10 + quantile_levels <- (1:9) / 10 } hardhat::check_quantile_levels(quantile_levels) # Pass some extra arguments to be used in post-processor @@ -61,7 +64,7 @@ predict_quantile.model_fit <- function(object, res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$pred$quantile$post)) { + if (!is.null(object$spec$method$pred$quantile$post)) { res <- object$spec$method$pred$quantile$post(res, object) } @@ -72,6 +75,6 @@ predict_quantile.model_fit <- function(object, # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_quantile <- function (object, ...) { +predict_quantile <- function(object, ...) { UseMethod("predict_quantile") } diff --git a/R/predict_raw.R b/R/predict_raw.R index 02d24ffeb..573d319f0 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -23,8 +23,9 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$raw$pre)) + if (!is.null(object$spec$method$pred$raw$pre)) { new_data <- object$spec$method$pred$raw$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$raw) @@ -38,5 +39,6 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { #' @export #' @keywords internal #' @rdname predict.model_fit -predict_raw <- function(object, ...) +predict_raw <- function(object, ...) { UseMethod("predict_raw") +} diff --git a/R/predict_survival.R b/R/predict_survival.R index 9aa99e483..393b21273 100644 --- a/R/predict_survival.R +++ b/R/predict_survival.R @@ -4,13 +4,15 @@ #' @method predict_survival model_fit #' @export predict_survival.model_fit #' @export -predict_survival.model_fit <- function(object, - new_data, - eval_time, - time = deprecated(), - interval = "none", - level = 0.95, - ...) { +predict_survival.model_fit <- function( + object, + new_data, + eval_time, + time = deprecated(), + interval = "none", + level = 0.95, + ... 
+) { if (lifecycle::is_present(time)) { lifecycle::deprecate_warn( "1.0.4.9005", @@ -31,8 +33,9 @@ predict_survival.model_fit <- function(object, new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$survival$pre)) + if (!is.null(object$spec$method$pred$survival$pre)) { new_data <- object$spec$method$pred$survival$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$survival) @@ -40,7 +43,7 @@ predict_survival.model_fit <- function(object, res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$pred$survival$post)) { + if (!is.null(object$spec$method$pred$survival$post)) { res <- object$spec$method$pred$survival$post(res, object) } @@ -51,5 +54,6 @@ predict_survival.model_fit <- function(object, #' @keywords internal #' @rdname other_predict #' @inheritParams predict_survival.model_fit -predict_survival <- function (object, ...) +predict_survival <- function(object, ...) { UseMethod("predict_survival") +} diff --git a/R/print.R b/R/print.R index d03dcfcdb..467c7388f 100644 --- a/R/print.R +++ b/R/print.R @@ -8,9 +8,17 @@ print.model_spec <- function(x, ...) { #' @keywords internal #' @rdname add_on_exports #' @export -print_model_spec <- function(x, cls = class(x)[1], desc = get_model_desc(cls), ...) { +print_model_spec <- function( + x, + cls = class(x)[1], + desc = get_model_desc(cls), + ... +) { if (!spec_is_loaded(spec = structure(x, class = cls))) { - prompt_missing_implementation(spec = structure(x, class = cls), prompt = cli::cli_inform) + prompt_missing_implementation( + spec = structure(x, class = cls), + prompt = cli::cli_inform + ) } mode <- switch(x$mode, unknown = "unknown mode", x$mode) @@ -41,39 +49,39 @@ get_model_desc <- function(cls) { } model_descs <- tibble::tribble( - ~cls, ~desc, - "auto_ml", "Automatic Machine Learning", - "bag_mars", "Bagged MARS", - "bag_mlp", "Bagged Neural Network", - "bag_tree", "Bagged Decision Tree", - "bart", "BART", - "boost_tree", "Boosted Tree", - "C5_rules", "C5.0", - "cubist_rules", "Cubist", - "decision_tree", "Decision Tree", - "discrim_flexible", "Flexible Discriminant", - "discrim_linear", "Linear Discriminant", - "discrim_quad", "Quadratic Discriminant", - "discrim_regularized", "Regularized Discriminant", - "gen_additive_mod", "GAM", - "linear_reg", "Linear Regression", - "logistic_reg", "Logistic Regression", - "mars", "MARS", - "mlp", "Single Layer Neural Network", - "multinom_reg", "Multinomial Regression", - "naive_Bayes", "Naive Bayes", - "nearest_neighbor", "K-Nearest Neighbor", - "null_model", "Null", - "pls", "PLS", - "poisson_reg", "Poisson Regression", - "proportional_hazards", "Proportional Hazards", - "rand_forest", "Random Forest", - "rule_fit", "RuleFit", - "surv_reg", "Parametric Survival Regression", - "survival_reg", "Parametric Survival Regression", - "svm_linear", "Linear Support Vector Machine", - "svm_poly", "Polynomial Support Vector Machine", - "svm_rbf", "Radial Basis Function Support Vector Machine" + ~cls , ~desc , + "auto_ml" , "Automatic Machine Learning" , + "bag_mars" , "Bagged MARS" , + "bag_mlp" , "Bagged Neural Network" , + "bag_tree" , "Bagged Decision Tree" , + "bart" , "BART" , + "boost_tree" , "Boosted Tree" , + "C5_rules" , "C5.0" , + "cubist_rules" , "Cubist" , + "decision_tree" , "Decision Tree" , + "discrim_flexible" , "Flexible Discriminant" , + "discrim_linear" , "Linear Discriminant" , + "discrim_quad" , "Quadratic Discriminant" , + "discrim_regularized" , 
"Regularized Discriminant" , + "gen_additive_mod" , "GAM" , + "linear_reg" , "Linear Regression" , + "logistic_reg" , "Logistic Regression" , + "mars" , "MARS" , + "mlp" , "Single Layer Neural Network" , + "multinom_reg" , "Multinomial Regression" , + "naive_Bayes" , "Naive Bayes" , + "nearest_neighbor" , "K-Nearest Neighbor" , + "null_model" , "Null" , + "pls" , "PLS" , + "poisson_reg" , "Poisson Regression" , + "proportional_hazards" , "Proportional Hazards" , + "rand_forest" , "Random Forest" , + "rule_fit" , "RuleFit" , + "surv_reg" , "Parametric Survival Regression" , + "survival_reg" , "Parametric Survival Regression" , + "svm_linear" , "Linear Support Vector Machine" , + "svm_poly" , "Polynomial Support Vector Machine" , + "svm_rbf" , "Radial Basis Function Support Vector Machine" ) #' Print helper for model objects @@ -103,10 +111,11 @@ model_printer <- function(x, ...) { cat("Fit function:\n") print(x$method$fit_call) if (length(x$method$libs) > 0) { - if (length(x$method$libs) > 1) + if (length(x$method$libs) > 1) { cat("\nRequired packages:\n") - else + } else { cat("\nRequired package: ") + } cat(paste0(x$method$libs, collapse = ", "), "\n") } } @@ -116,7 +125,7 @@ model_printer <- function(x, ...) { print_arg_list <- function(x, ...) { atomic <- vapply(x, is.atomic, logical(1)) x2 <- x - x2[!atomic] <- lapply(x2[!atomic], deparserizer, ...) + x2[!atomic] <- lapply(x2[!atomic], deparserizer, ...) res <- paste0(" ", names(x2), " = ", x2, collaspe = "\n") cat(res, sep = "") } @@ -125,7 +134,8 @@ deparserizer <- function(x, limit = options()$width - 10) { x <- deparse(x, width.cutoff = limit) x <- gsub("^ ", "", x) x <- paste0(x, collapse = "") - if (nchar(x) > limit) + if (nchar(x) > limit) { x <- paste0(substring(x, first = 1, last = limit - 7), "") + } x } diff --git a/R/proportional_hazards.R b/R/proportional_hazards.R index 038e19c94..46bd1a9e4 100644 --- a/R/proportional_hazards.R +++ b/R/proportional_hazards.R @@ -37,50 +37,52 @@ proportional_hazards <- function( mode = "censored regression", engine = "survival", penalty = NULL, - mixture = NULL) { + mixture = NULL +) { + args <- list( + penalty = enquo(penalty), + mixture = enquo(mixture) + ) - args <- list( - penalty = enquo(penalty), - mixture = enquo(mixture) - ) - - new_model_spec( - "proportional_hazards", - args = args, - eng_args = NULL, - mode = mode, - user_specified_mode = !missing(mode), - method = NULL, - engine = engine, - user_specified_engine = !missing(engine) - ) - } + new_model_spec( + "proportional_hazards", + args = args, + eng_args = NULL, + mode = mode, + user_specified_mode = !missing(mode), + method = NULL, + engine = engine, + user_specified_engine = !missing(engine) + ) +} # ------------------------------------------------------------------------------ #' @method update proportional_hazards #' @rdname parsnip_update #' @export -update.proportional_hazards <- function(object, - parameters = NULL, - penalty = NULL, - mixture = NULL, - fresh = FALSE, ...) { - - args <- list( - penalty = enquo(penalty), - mixture = enquo(mixture) - ) +update.proportional_hazards <- function( + object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, + ... +) { + args <- list( + penalty = enquo(penalty), + mixture = enquo(mixture) + ) - update_spec( - object = object, - parameters = parameters, - args_enquo_list = args, - fresh = fresh, - cls = "proportional_hazards", - ... 
- ) - } + update_spec( + object = object, + parameters = parameters, + args_enquo_list = args, + fresh = fresh, + cls = "proportional_hazards", + ... + ) +} #' @export translate.proportional_hazards <- function(x, engine = x$engine, ...) { diff --git a/R/proportional_hazards_data.R b/R/proportional_hazards_data.R index 9514b18d9..3a765fac5 100644 --- a/R/proportional_hazards_data.R +++ b/R/proportional_hazards_data.R @@ -1,4 +1,3 @@ - # parsnip just contains the model specification, the engines are the censored package. set_new_model("proportional_hazards") diff --git a/R/repair_call.R b/R/repair_call.R index 81d1005b1..0e0a38817 100644 --- a/R/repair_call.R +++ b/R/repair_call.R @@ -39,7 +39,7 @@ repair_call <- function(x, data) { needs_eval <- purrr::map_lgl(fit_call, rlang::is_quosure) if (any(needs_eval)) { eval_args <- names(needs_eval)[needs_eval] - for(arg in eval_args) { + for (arg in eval_args) { fit_call[[arg]] <- rlang::eval_tidy(fit_call[[arg]]) } } diff --git a/R/rule_fit.R b/R/rule_fit.R index 81a6ce48d..3775831eb 100644 --- a/R/rule_fit.R +++ b/R/rule_fit.R @@ -37,15 +37,19 @@ #' #' @export rule_fit <- - function(mode = "unknown", - mtry = NULL, trees = NULL, min_n = NULL, - tree_depth = NULL, learn_rate = NULL, - loss_reduction = NULL, - sample_size = NULL, - stop_iter = NULL, - penalty = NULL, - engine = "xrf") { - + function( + mode = "unknown", + mtry = NULL, + trees = NULL, + min_n = NULL, + tree_depth = NULL, + learn_rate = NULL, + loss_reduction = NULL, + sample_size = NULL, + stop_iter = NULL, + penalty = NULL, + engine = "xrf" + ) { args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -58,7 +62,6 @@ rule_fit <- penalty = enquo(penalty) ) - new_model_spec( "rule_fit", args = args, @@ -87,14 +90,20 @@ rule_fit <- #' @inheritParams rule_fit #' @export update.rule_fit <- - function(object, - parameters = NULL, - mtry = NULL, trees = NULL, min_n = NULL, - tree_depth = NULL, learn_rate = NULL, - loss_reduction = NULL, sample_size = NULL, - penalty = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + mtry = NULL, + trees = NULL, + min_n = NULL, + tree_depth = NULL, + learn_rate = NULL, + loss_reduction = NULL, + sample_size = NULL, + penalty = NULL, + fresh = FALSE, + ... + ) { args <- list( mtry = enquo(mtry), trees = enquo(trees), diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index baf485e64..4f9a2bd81 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -1,18 +1,18 @@ #' Using sparse data with parsnip -#' -#' You can figure out whether a given model engine supports sparse data by +#' +#' You can figure out whether a given model engine supports sparse data by #' calling `get_encoding("name of model")` and looking at the `allow_sparse_x` #' column. -#' -#' Using sparse data for model fitting and prediction shouldn't require any -#' additional configurations. Just pass in a sparse matrix such as dgCMatrix -#' from the `Matrix` package or a sparse tibble from the sparsevctrs package +#' +#' Using sparse data for model fitting and prediction shouldn't require any +#' additional configurations. Just pass in a sparse matrix such as dgCMatrix +#' from the `Matrix` package or a sparse tibble from the sparsevctrs package #' to the data argument of [fit()], [fit_xy()], and [predict()]. -#' -#' Models that don't support sparse data will try to convert to non-sparse data -#' with warnings. If conversion isn’t possible, an informative error will be +#' +#' Models that don't support sparse data will try to convert to non-sparse data +#' with warnings. 
If conversion isn’t possible, an informative error will be #' thrown. -#' +#' #' @name sparse_data NULL @@ -24,7 +24,7 @@ to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) { if (inherits(object, "model_fit")) { object <- object$spec } - + cli::cli_abort( "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with engine {.val {object$engine}} doesn't accept that.", @@ -46,7 +46,7 @@ materialize_sparse_tibble <- function(x, object, input) { if (inherits(object, "model_fit")) { object <- object$spec } - + cli::cli_warn( "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with engine {.val {object$engine}} doesn't accept that. Converting to diff --git a/R/standalone-survival.R b/R/standalone-survival.R index 399e2e11c..3f3809cc6 100644 --- a/R/standalone-survival.R +++ b/R/standalone-survival.R @@ -63,10 +63,7 @@ } .check_cens_type <- - function(surv, - type = "right", - fail = TRUE, - call = rlang::caller_env()) { + function(surv, type = "right", fail = TRUE, call = rlang::caller_env()) { .is_surv(surv, call = call) obj_type <- .extract_surv_type(surv) good_type <- all(obj_type %in% type) @@ -103,13 +100,14 @@ .extract_surv_status <- function(surv) { .is_surv(surv) - res <- surv[, "status"] + res <- surv[, "status"] un_vals <- sort(unique(res)) event_type_to_01 <- !(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate")) if ( event_type_to_01 && - (identical(un_vals, 1:2) | identical(un_vals, c(1.0, 2.0))) ) { + (identical(un_vals, 1:2) | identical(un_vals, c(1.0, 2.0))) + ) { res <- res - 1 } unname(res) diff --git a/R/survival-censoring-model.R b/R/survival-censoring-model.R index 91fb78f0e..b5cfbff38 100644 --- a/R/survival-censoring-model.R +++ b/R/survival-censoring-model.R @@ -2,13 +2,24 @@ # tested in tidymodels/extratests#67 new_reverse_km_fit <- - function(formula, - object, - pkgs = character(0), - label = character(0), - extra_cls = character(0)) { - res <- list(formula = formula, fit = object, label = label, required_pkgs = pkgs) - class(res) <- c(paste0("censoring_model_", label), "censoring_model", extra_cls) + function( + formula, + object, + pkgs = character(0), + label = character(0), + extra_cls = character(0) + ) { + res <- list( + formula = formula, + fit = object, + label = label, + required_pkgs = pkgs + ) + class(res) <- c( + paste0("censoring_model_", label), + "censoring_model", + extra_cls + ) res } @@ -24,7 +35,7 @@ reverse_km <- function(obj, eval_env) { # Note: even when fit_xy() is called, eval_env will still have # objects data and formula in them f <- eval_env$formula - km_form <- stats::update(f, ~ 1) + km_form <- stats::update(f, ~1) cl <- rlang::call2( "prodlim", @@ -61,7 +72,13 @@ predict.censoring_model <- function(object, ...) { } #' @export -predict.censoring_model_reverse_km <- function(object, new_data, time, as_vector = FALSE, ...) { +predict.censoring_model_reverse_km <- function( + object, + new_data, + time, + as_vector = FALSE, + ... +) { rlang::check_dots_empty() rlang::check_installed("prodlim", version = "2022.10.13") diff --git a/R/survival-censoring-weights.R b/R/survival-censoring-weights.R index 08bea6ad6..8a31c2ac5 100644 --- a/R/survival-censoring-weights.R +++ b/R/survival-censoring-weights.R @@ -191,10 +191,14 @@ graf_weight_time_vec <- function(surv_obj, eval_time, eps = 10^-10) { #' @export #' @rdname censoring_weights -.censoring_weights_graf.model_fit <- function(object, - predictions, - cens_predictors = NULL, - trunc = 0.05, eps = 10^-10, ...) 
{ +.censoring_weights_graf.model_fit <- function( + object, + predictions, + cens_predictors = NULL, + trunc = 0.05, + eps = 10^-10, + ... +) { rlang::check_dots_empty() .check_censor_model(object) truth <- .find_surv_col(predictions) @@ -205,30 +209,42 @@ graf_weight_time_vec <- function(surv_obj, eval_time, eps = 10^-10) { cli::cli_warn("{.arg cens_predictors} is not currently used.") } predictions$.pred <- - add_graf_weights_vec(object, - predictions$.pred, - predictions[[truth]], - trunc = trunc, - eps = eps) + add_graf_weights_vec( + object, + predictions$.pred, + predictions[[truth]], + trunc = trunc, + eps = eps + ) predictions } # ------------------------------------------------------------------------------ # Helpers -add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10^-10) { +add_graf_weights_vec <- function( + object, + .pred, + surv_obj, + trunc = 0.05, + eps = 10^-10 +) { # Expand the list column to one data frame n <- length(.pred) num_times <- vctrs::list_sizes(.pred) y <- vctrs::list_unchop(.pred) y$surv_obj <- vctrs::vec_rep_each(surv_obj, times = num_times) - names(y)[names(y) == ".time"] <- ".eval_time" # Temporary + names(y)[names(y) == ".time"] <- ".eval_time" # Temporary # Compute the actual time of evaluation y$.weight_time <- graf_weight_time_vec(y$surv_obj, y$.eval_time, eps = eps) # Compute the corresponding probability of being censored - y$.pred_censored <- predict(object$censor_probs, time = y$.weight_time, as_vector = TRUE) + y$.pred_censored <- predict( + object$censor_probs, + time = y$.weight_time, + as_vector = TRUE + ) y$.pred_censored <- trunc_probs(y$.pred_censored, trunc = trunc) # Invert the probabilities to create weights y$.weight_censored = 1 / y$.pred_censored @@ -243,7 +259,10 @@ add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10 is_surv <- purrr::map_lgl(x[!is_lst_col], .is_surv, fail = FALSE) num_surv <- sum(is_surv) if (fail && num_surv != 1) { - cli::cli_abort("There should be a single column of class {.cls Surv}.", call = call) + cli::cli_abort( + "There should be a single column of class {.cls Surv}.", + call = call + ) } names(is_surv)[is_surv] } diff --git a/R/survival_reg.R b/R/survival_reg.R index d34781ee9..7c222bf58 100644 --- a/R/survival_reg.R +++ b/R/survival_reg.R @@ -31,8 +31,11 @@ #' survival_reg(mode = "censored regression", dist = "weibull") #' @keywords internal #' @export -survival_reg <- function(mode = "censored regression", engine = "survival", dist = NULL) { - +survival_reg <- function( + mode = "censored regression", + engine = "survival", + dist = NULL +) { args <- list( dist = enquo(dist) ) @@ -54,8 +57,13 @@ survival_reg <- function(mode = "censored regression", engine = "survival", dist #' @method update survival_reg #' @rdname parsnip_update #' @export -update.survival_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALSE, ...) { - +update.survival_reg <- function( + object, + parameters = NULL, + dist = NULL, + fresh = FALSE, + ... +) { args <- list( dist = enquo(dist) ) @@ -83,14 +91,13 @@ translate.survival_reg <- function(x, engine = x$engine, ...) 
{ #' @export check_args.survival_reg <- function(object, call = rlang::caller_env()) { - if (object$engine == "flexsurv") { - args <- lapply(object$args, rlang::eval_tidy) # `dist` has no default in the function - if (all(names(args) != "dist") || is.null(args$dist)) + if (all(names(args) != "dist") || is.null(args$dist)) { object$args$dist <- "weibull" + } } invisible(object) diff --git a/R/survival_reg_data.R b/R/survival_reg_data.R index 659e642b0..a999de57b 100644 --- a/R/survival_reg_data.R +++ b/R/survival_reg_data.R @@ -1,4 +1,3 @@ - set_new_model("survival_reg") set_model_mode("survival_reg", "censored regression") diff --git a/R/svm_linear.R b/R/svm_linear.R index c45bb6fea..b6e20f5bb 100644 --- a/R/svm_linear.R +++ b/R/svm_linear.R @@ -33,11 +33,9 @@ #' @export svm_linear <- - function(mode = "unknown", engine = "LiblineaR", - cost = NULL, margin = NULL) { - + function(mode = "unknown", engine = "LiblineaR", cost = NULL, margin = NULL) { args <- list( - cost = enquo(cost), + cost = enquo(cost), margin = enquo(margin) ) @@ -59,15 +57,17 @@ svm_linear <- #' @rdname parsnip_update #' @export update.svm_linear <- - function(object, - parameters = NULL, - cost = NULL, margin = NULL, - fresh = FALSE, - ...) { - + function( + object, + parameters = NULL, + cost = NULL, + margin = NULL, + fresh = FALSE, + ... + ) { args <- list( - cost = enquo(cost), - margin = enquo(margin) + cost = enquo(cost), + margin = enquo(margin) ) update_spec( @@ -93,7 +93,6 @@ translate.svm_linear <- function(x, engine = x$engine, ...) { # add checks to error trap or change things for this method if (x$engine == "LiblineaR") { - if (is_null(x$eng_args$type)) { liblinear_type <- NULL } else { @@ -101,33 +100,37 @@ translate.svm_linear <- function(x, engine = x$engine, ...) { } if (x$mode == "regression") { - if (is_null(quo_get_expr(x$args$margin))) + if (is_null(quo_get_expr(x$args$margin))) { arg_vals$svr_eps <- 0.1 - if (!is_null(liblinear_type)) - if(!liblinear_type %in% 11:13) + } + if (!is_null(liblinear_type)) { + if (!liblinear_type %in% 11:13) { cli::cli_abort( "The LiblineaR engine argument {.code type = {liblinear_type}} does not correspond to an SVM regression model." ) + } + } } else if (x$mode == "classification") { - if (!is_null(liblinear_type)) + if (!is_null(liblinear_type)) { if (!liblinear_type %in% 1:5) { cli::cli_abort( "The LiblineaR engine argument of {.code type = {liblinear_type}} does not correspond to an SVM classification model." 
) } + } } } if (x$engine == "kernlab") { - # unless otherwise specified, classification models predict probabilities - if (x$mode == "classification" && !any(arg_names == "prob.model")) + if (x$mode == "classification" && !any(arg_names == "prob.model")) { arg_vals$prob.model <- TRUE - if (x$mode == "classification" && any(arg_names == "epsilon")) + } + if (x$mode == "classification" && any(arg_names == "epsilon")) { arg_vals$epsilon <- NULL - + } } x$method$fit$args <- arg_vals @@ -150,6 +153,5 @@ svm_linear_post <- function(results, object) { } svm_reg_linear_post <- function(results, object) { - results[,1] + results[, 1] } - diff --git a/R/svm_linear_data.R b/R/svm_linear_data.R index 36543bdae..a3926ce84 100644 --- a/R/svm_linear_data.R +++ b/R/svm_linear_data.R @@ -87,11 +87,10 @@ set_pred( pre = NULL, post = svm_linear_post, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(new_data) - ) + args = list( + object = quote(object$fit), + newx = quote(new_data) + ) ) ) set_pred( @@ -105,7 +104,8 @@ set_pred( func = c(fun = "predict"), args = list( object = quote(object$fit), - newx = quote(new_data)) + newx = quote(new_data) + ) ) ) set_pred( @@ -117,11 +117,10 @@ set_pred( pre = NULL, post = svm_linear_post, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = expr(as.matrix(new_data)) - ) + args = list( + object = quote(object$fit), + newx = expr(as.matrix(new_data)) + ) ) ) set_pred( @@ -135,7 +134,8 @@ set_pred( func = c(fun = "predict"), args = list( object = quote(object$fit), - newx = quote(new_data)) + newx = quote(new_data) + ) ) ) @@ -210,12 +210,11 @@ set_pred( pre = NULL, post = svm_reg_linear_post, func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -253,12 +252,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -271,12 +269,11 @@ set_pred( pre = NULL, post = function(result, object) as_tibble(result), func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) ) ) @@ -292,4 +289,3 @@ set_pred( args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) - diff --git a/R/svm_poly.R b/R/svm_poly.R index 4acd1afe8..2a01b86f1 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -36,13 +36,18 @@ #' @export svm_poly <- - function(mode = "unknown", engine = "kernlab", - cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL) { - + function( + mode = "unknown", + engine = "kernlab", + cost = NULL, + degree = NULL, + scale_factor = NULL, + margin = NULL + ) { args <- list( - cost = enquo(cost), - degree = enquo(degree), - scale_factor = enquo(scale_factor), + cost = enquo(cost), + degree = enquo(degree), + scale_factor = enquo(scale_factor), margin = enquo(margin) ) @@ -64,17 +69,21 @@ svm_poly <- #' @rdname parsnip_update #' @export update.svm_poly <- - function(object, - parameters = NULL, - cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, - fresh = FALSE, - ...) 
{ - + function( + object, + parameters = NULL, + cost = NULL, + degree = NULL, + scale_factor = NULL, + margin = NULL, + fresh = FALSE, + ... + ) { args <- list( - cost = enquo(cost), - degree = enquo(degree), - scale_factor = enquo(scale_factor), - margin = enquo(margin) + cost = enquo(cost), + degree = enquo(degree), + scale_factor = enquo(scale_factor), + margin = enquo(margin) ) update_spec( @@ -99,12 +108,13 @@ translate.svm_poly <- function(x, engine = x$engine, ...) { # add checks to error trap or change things for this method if (x$engine == "kernlab") { - # unless otherwise specified, classification models predict probabilities - if (x$mode == "classification" && !any(arg_names == "prob.model")) + if (x$mode == "classification" && !any(arg_names == "prob.model")) { arg_vals$prob.model <- TRUE - if (x$mode == "classification" && any(arg_names == "epsilon")) + } + if (x$mode == "classification" && any(arg_names == "epsilon")) { arg_vals$epsilon <- NULL + } # convert degree and scale to a `kpar` argument. if (any(arg_names %in% c("degree", "scale", "offset"))) { @@ -123,7 +133,6 @@ translate.svm_poly <- function(x, engine = x$engine, ...) { } arg_vals$kpar <- kpar } - } x$method$fit$args <- arg_vals @@ -141,6 +150,5 @@ check_args.svm_poly <- function(object, call = rlang::caller_env()) { # ------------------------------------------------------------------------------ svm_reg_post <- function(results, object) { - results[,1] + results[, 1] } - diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index 69de824ca..66efc9978 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -103,12 +103,11 @@ set_pred( pre = NULL, post = svm_reg_post, func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -134,12 +133,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -152,12 +150,11 @@ set_pred( pre = NULL, post = function(result, object) as_tibble(result), func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) ) ) @@ -173,4 +170,3 @@ set_pred( args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) - diff --git a/R/svm_rbf.R b/R/svm_rbf.R index ba8abf272..2b55ca1c6 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -38,12 +38,16 @@ #' @export svm_rbf <- - function(mode = "unknown", engine = "kernlab", - cost = NULL, rbf_sigma = NULL, margin = NULL) { - + function( + mode = "unknown", + engine = "kernlab", + cost = NULL, + rbf_sigma = NULL, + margin = NULL + ) { args <- list( - cost = enquo(cost), - rbf_sigma = enquo(rbf_sigma), + cost = enquo(cost), + rbf_sigma = enquo(rbf_sigma), margin = enquo(margin) ) @@ -65,16 +69,19 @@ svm_rbf <- #' @rdname parsnip_update #' @export update.svm_rbf <- - function(object, - parameters = NULL, - cost = NULL, rbf_sigma = NULL, margin = NULL, - fresh = FALSE, - ...) { - + function( + object, + parameters = NULL, + cost = NULL, + rbf_sigma = NULL, + margin = NULL, + fresh = FALSE, + ... 
+ ) { args <- list( - cost = enquo(cost), - rbf_sigma = enquo(rbf_sigma), - margin = enquo(margin) + cost = enquo(cost), + rbf_sigma = enquo(rbf_sigma), + margin = enquo(margin) ) update_spec( @@ -99,12 +106,13 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) { # add checks to error trap or change things for this method if (x$engine == "kernlab") { - # unless otherwise specified, classification models predict probabilities - if (x$mode == "classification" && !any(arg_names == "prob.model")) + if (x$mode == "classification" && !any(arg_names == "prob.model")) { arg_vals$prob.model <- TRUE - if (x$mode == "classification" && any(arg_names == "epsilon")) + } + if (x$mode == "classification" && any(arg_names == "epsilon")) { arg_vals$epsilon <- NULL + } # convert sigma and scale to a `kpar` argument. if (any(arg_names == "sigma")) { @@ -113,7 +121,6 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) { arg_vals$sigma <- NULL arg_vals$kpar <- kpar } - } if (x$engine == "liquidSVM") { @@ -127,7 +134,6 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) { arg_vals$lambdas <- arg_vals$C arg_vals$C <- NULL } - } x$method$fit$args <- arg_vals @@ -146,6 +152,5 @@ check_args.svm_rbf <- function(object, call = rlang::caller_env()) { # ------------------------------------------------------------------------------ svm_reg_post <- function(results, object) { - results[,1] + results[, 1] } - diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index 5106ff4f1..28644e38f 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -83,12 +83,11 @@ set_pred( pre = NULL, post = svm_reg_post, func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -126,12 +125,11 @@ set_pred( pre = NULL, post = NULL, func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) ) @@ -144,12 +142,11 @@ set_pred( pre = NULL, post = function(result, object) as_tibble(result), func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) ) ) @@ -251,11 +248,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) set_pred( @@ -269,7 +265,8 @@ set_pred( func = c(fun = "predict"), args = list( object = quote(object$fit), - newdata = quote(new_data)) + newdata = quote(new_data) + ) ) ) set_pred( @@ -281,11 +278,10 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) ) set_pred( @@ -311,12 +307,11 @@ set_pred( res }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - predict.prob = TRUE - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + predict.prob = TRUE + ) ) ) set_pred( @@ -330,6 +325,7 @@ set_pred( func = c(fun = "predict"), args = list( 
object = quote(object$fit), - newdata = quote(new_data)) + newdata = quote(new_data) + ) ) ) diff --git a/R/tidy_liblinear.R b/R/tidy_liblinear.R index 8cfb416e4..a5e80acab 100644 --- a/R/tidy_liblinear.R +++ b/R/tidy_liblinear.R @@ -10,7 +10,7 @@ tidy._LiblineaR <- function(x, ...) { check_installs(x$spec) - ret <- tibble(colnames(x$fit$W), x$fit$W[1,]) + ret <- tibble(colnames(x$fit$W), x$fit$W[1, ]) colnames(ret) <- c("term", "estimate") ret diff --git a/R/translate.R b/R/translate.R index fb6b020ba..2057c22d3 100644 --- a/R/translate.R +++ b/R/translate.R @@ -44,8 +44,9 @@ #' #' @export -translate <- function(x, ...) +translate <- function(x, ...) { UseMethod("translate") +} #' @rdname translate #' @export @@ -108,11 +109,22 @@ get_model_spec <- function(model, mode, engine) { libs <- rlang::env_get(m_env, paste0(model, "_pkgs")) libs <- vctrs::vec_slice(libs$pkg, libs$engine == engine) - res$libs <- if (length(libs) > 0) {libs[[1]]} else {NULL} + res$libs <- if (length(libs) > 0) { + libs[[1]] + } else { + NULL + } fits <- rlang::env_get(m_env, paste0(model, "_fit")) - fits <- vctrs::vec_slice(fits$value, fits$mode == mode & fits$engine == engine) - res$fit <- if (length(fits) > 0) {fits[[1]]} else {NULL} + fits <- vctrs::vec_slice( + fits$value, + fits$mode == mode & fits$engine == engine + ) + res$fit <- if (length(fits) > 0) { + fits[[1]] + } else { + NULL + } preds <- rlang::env_get(m_env, paste0(model, "_predict")) where <- preds$mode == mode & preds$engine == engine @@ -151,7 +163,7 @@ deharmonize <- function(args, key) { dplyr::left_join(parsn, key, by = "parsnip") |> dplyr::arrange(order) - merged <- merged[!duplicated(merged$order),] + merged <- merged[!duplicated(merged$order), ] names(args) <- merged$original args[!is.na(merged$original)] @@ -222,4 +234,3 @@ add_methods <- function(x, engine) { } res } - diff --git a/R/tune_args.R b/R/tune_args.R index a37e68fe9..848d59f1b 100644 --- a/R/tune_args.R +++ b/R/tune_args.R @@ -1,7 +1,6 @@ #' @method tune_args model_spec #' @export tune_args.model_spec <- function(object, full = FALSE, ...) { - # use the model_spec top level class as the id model_type <- class(object)[1] @@ -10,7 +9,7 @@ tune_args.model_spec <- function(object, full = FALSE, ...) { } # Locate tunable args in spec args and engine specific args - object$args <- purrr::map(object$args, convert_args) + object$args <- purrr::map(object$args, convert_args) object$eng_args <- purrr::map(object$eng_args, convert_args) arg_id <- purrr::map_chr(object$args, find_tune_id) @@ -32,11 +31,9 @@ tune_args.model_spec <- function(object, full = FALSE, ...) { } - # helpers for tune_args() methods ----------------------------------------- # they also exist in recipes for the `tune_args()` methods there - # If we map over a list or arguments and some are quosures, we get the message # that "Subsetting quosures with `[[` is deprecated as of rlang 0.4.0" @@ -50,22 +47,27 @@ convert_args <- function(x) { # useful for standardization and for creating a 0 row tunable tbl # (i.e. 
for when there are no steps in a recipe) -tune_tbl <- function(name = character(), - tunable = logical(), - id = character(), - source = character(), - component = character(), - component_id = character(), - full = FALSE, - call = caller_env()) { - +tune_tbl <- function( + name = character(), + tunable = logical(), + id = character(), + source = character(), + component = character(), + component_id = character(), + full = FALSE, + call = caller_env() +) { check_bool(full, call = call) complete_id <- id[!is.na(id)] dups <- duplicated(complete_id) if (any(dups)) { - stop("There are duplicate `id` values listed in [tune()]: ", - paste0("'", unique(complete_id[dups]), "'", collapse = ", "), - ".", sep = "", call. = FALSE) + stop( + "There are duplicate `id` values listed in [tune()]: ", + paste0("'", unique(complete_id[dups]), "'", collapse = ", "), + ".", + sep = "", + call. = FALSE + ) } vry_tbl <- tibble::new_tibble( @@ -81,7 +83,7 @@ tune_tbl <- function(name = character(), ) if (!full) { - vry_tbl <- vry_tbl[vry_tbl$tunable,] + vry_tbl <- vry_tbl[vry_tbl$tunable, ] } vry_tbl @@ -126,7 +128,6 @@ tune_id <- function(x) { } find_tune_id <- function(x) { - # STEP 1 - Early exits # Early exit for empty elements (like list()) @@ -176,7 +177,8 @@ find_tune_id <- function(x) { "The current argument has: `", paste0(deparse(x), collapse = ""), "`.", - call. = FALSE) + call. = FALSE + ) } return(tunable_elems) diff --git a/R/type_sum.R b/R/type_sum.R index ab846d538..9b22e8ddd 100644 --- a/R/type_sum.R +++ b/R/type_sum.R @@ -16,11 +16,13 @@ #' @export type_sum.model_spec <- function(x) { resolved <- TRUE - if (x$mode == "unknown") + if (x$mode == "unknown") { resolved <- FALSE + } arg_info <- generics::tune_args(x) - if (any(arg_info$tunable)) + if (any(arg_info$tunable)) { resolved <- FALSE + } res <- "spec" if (resolved) { @@ -37,8 +39,9 @@ type_sum.model_spec <- function(x) { #' @export type_sum.model_fit <- function(x) { resolved <- TRUE - if (inherits(x$fit, "try-error")) + if (inherits(x$fit, "try-error")) { resolved <- FALSE + } res <- "fit" if (resolved) { diff --git a/R/update.R b/R/update.R index b242ffe5c..9760ded0e 100644 --- a/R/update.R +++ b/R/update.R @@ -58,8 +58,15 @@ NULL #' @export #' @keywords internal #' @rdname add_on_exports -update_spec <- function(object, parameters, args_enquo_list, fresh, cls, ..., - call = caller_env()) { +update_spec <- function( + object, + parameters, + args_enquo_list, + fresh, + cls, + ..., + call = caller_env() +) { check_bool(fresh, call = call) eng_args <- update_engine_parameters(object$eng_args, fresh, ...) @@ -74,12 +81,15 @@ update_spec <- function(object, parameters, args_enquo_list, fresh, cls, ..., object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) - if (any(null_args)) + if (any(null_args)) { args <- args[!null_args] - if (length(args) > 0) + } + if (length(args) > 0) { object$args[names(args)] <- args - if (length(eng_args) > 0) + } + if (length(eng_args) > 0) { object$eng_args[names(eng_args)] <- eng_args + } } new_model_spec( diff --git a/R/varying.R b/R/varying.R index de29c5e56..bd8c322bc 100644 --- a/R/varying.R +++ b/R/varying.R @@ -149,12 +149,13 @@ varying_args.step <- function(object, full = TRUE, ...) { # useful for standardization and for creating a 0 row varying tbl # (i.e. 
for when there are no steps in a recipe) -varying_tbl <- function(name = character(), - varying = logical(), - id = character(), - type = character(), - full = FALSE) { - +varying_tbl <- function( + name = character(), + varying = logical(), + id = character(), + type = character(), + full = FALSE +) { vry_tbl <- tibble( name = name, varying = varying, @@ -163,16 +164,14 @@ varying_tbl <- function(name = character(), ) if (!full) { - vry_tbl <- vry_tbl[vry_tbl$varying,] + vry_tbl <- vry_tbl[vry_tbl$varying, ] } vry_tbl } validate_only_allowed_step_args <- function(x, step_type) { - check_allowed_arg <- function(x, nm) { - # not varying if (rlang::is_false(x)) { return(invisible(x)) @@ -195,9 +194,18 @@ validate_only_allowed_step_args <- function(x, step_type) { } non_varying_step_arguments <- c( - "terms", "role", "trained", "skip", - "na.rm", "impute_with", "seed", - "prefix", "naming", "denom", "outcome", "id" + "terms", + "role", + "trained", + "skip", + "na.rm", + "impute_with", + "seed", + "prefix", + "naming", + "denom", + "outcome", + "id" ) # helpers ---------------------------------------------------------------------- @@ -219,7 +227,6 @@ is_varying <- function(x) { } find_varying <- function(x) { - # STEP 1 - Early exits # Early exit for empty elements (like list()) diff --git a/inst/add-in/gadget.R b/inst/add-in/gadget.R index abdaf9643..85e2e4031 100644 --- a/inst/add-in/gadget.R +++ b/inst/add-in/gadget.R @@ -41,7 +41,7 @@ parsnip_spec_add_in <- function() { cl_1 <- rlang::call2(.ns = pkg, .fn = x$model) } - obj_nm <- paste0(x$model,"_", x$engine, "_spec") + obj_nm <- paste0(x$model, "_", x$engine, "_spec") chr_1 <- rlang::expr_text(cl_1, width = 500) chr_1 <- paste0(chr_1, collapse = " ") chr_1 <- paste(obj_nm, "<-\n ", chr_1) @@ -88,28 +88,25 @@ parsnip_spec_add_in <- function() { ) ) - server <- function(input, output) { get_models <- reactive({ req(input$model_mode) - models <- model_db[model_db$mode == tolower(input$model_mode),] + models <- model_db[model_db$mode == tolower(input$model_mode), ] if (nchar(input$pattern) > 0) { - incld <- grepl(input$pattern, models$model) | grepl(input$pattern, models$engine) - models <- models[incld,] - + incld <- grepl(input$pattern, models$model) | + grepl(input$pattern, models$engine) + models <- models[incld, ] } models }) # get_models output$model_choices <- renderUI({ - model_list <- get_models() if (nrow(model_list) > 0) { - - choices <- paste0(model_list$model, " (", model_list$engine, ")") - choices <- unique(choices) + choices <- paste0(model_list$model, " (", model_list$engine, ")") + choices <- unique(choices) } else { choices <- NULL } @@ -122,19 +119,19 @@ parsnip_spec_add_in <- function() { }) # model_choices create_code <- reactive({ - req(input$model_name) req(input$model_mode) model_mode <- tolower(input$model_mode) - selected <- model_db[model_db$label %in% input$model_name,] - selected <- selected[selected$mode %in% model_mode,] + selected <- model_db[model_db$label %in% input$model_name, ] + selected <- selected[selected$mode %in% model_mode, ] - res <- purrr::map_chr(1:nrow(selected), - ~ make_spec(selected[.x,], tune_args = input$tune_args)) + res <- purrr::map_chr( + 1:nrow(selected), + ~ make_spec(selected[.x, ], tune_args = input$tune_args) + ) paste0(res, sep = "\n\n") - }) # create_code observeEvent(input$write, { @@ -154,4 +151,3 @@ parsnip_spec_add_in <- function() { } parsnip_spec_add_in() - diff --git a/inst/add-in/parsnip_model_db.R b/inst/add-in/parsnip_model_db.R index f1f6c5fc6..b3cf38b0b 100644 --- 
a/inst/add-in/parsnip_model_db.R +++ b/inst/add-in/parsnip_model_db.R @@ -7,23 +7,38 @@ library(tidymodels) library(usethis) # also requires installation of: -packages <- c("parsnip", "discrim", "plsmod", "rules", "baguette", "poissonreg", - "multilevelmod", "modeltime", "modeltime.gluonts") +packages <- c( + "parsnip", + "discrim", + "plsmod", + "rules", + "baguette", + "poissonreg", + "multilevelmod", + "modeltime", + "modeltime.gluonts" +) # ------------------------------------------------------------------------------ # Detects model specifications via their print methods print_methods <- function(x) { - require(x, character.only = TRUE) + require(x, character.only = TRUE) ns <- asNamespace(ns = x) mthds <- ls(envir = ns, pattern = "^print\\.") mthds <- gsub("^print\\.", "", mthds) - purrr::map(mthds, get_engines) |> purrr::list_rbind() |> dplyr::mutate(package = x) + purrr::map(mthds, get_engines) |> + purrr::list_rbind() |> + dplyr::mutate(package = x) } get_engines <- function(x) { eng <- try(parsnip::show_engines(x), silent = TRUE) if (inherits(eng, "try-error")) { - eng <- tibble::tibble(engine = NA_character_, mode = NA_character_, model = x) + eng <- tibble::tibble( + engine = NA_character_, + mode = NA_character_, + model = x + ) } else { eng$model <- x } @@ -42,22 +57,21 @@ get_tunable_param <- function(mode, package, model, engine) { # Edit some model parameters if (model == "rand_forest") { - res <- res[res$parameter != "trees",] + res <- res[res$parameter != "trees", ] } if (model == "mars") { - res <- res[res$parameter == "prod_degree",] + res <- res[res$parameter == "prod_degree", ] } if (engine %in% c("rule_fit", "xgboost")) { - res <- res[res$parameter != "mtry",] + res <- res[res$parameter != "mtry", ] } if (model %in% c("bag_tree", "bag_mars")) { - res <- res[0,] + res <- res[0, ] } if (engine %in% c("rpart")) { - res <- res[res$parameter != "tree-depth",] + res <- res[res$parameter != "tree-depth", ] } res - } # ------------------------------------------------------------------------------ @@ -82,7 +96,11 @@ num_modes <- model_db <- dplyr::left_join(model_db, num_modes, by = c("package", "model", "engine")) |> - dplyr::mutate(parameters = purrr::pmap(list(mode, package, model, engine), get_tunable_param)) + dplyr::mutate( + parameters = purrr::pmap( + list(mode, package, model, engine), + get_tunable_param + ) + ) usethis::use_data(model_db, overwrite = TRUE) - diff --git a/tests/testthat.R b/tests/testthat.R index bcb3cd249..2a3c3f952 100644 --- a/tests/testthat.R +++ b/tests/testthat.R @@ -2,4 +2,3 @@ library(testthat) library(parsnip) test_check("parsnip") - diff --git a/tests/testthat/helper-extract_parameter_set.R b/tests/testthat/helper-extract_parameter_set.R index 25480d090..6c3bed465 100644 --- a/tests/testthat/helper-extract_parameter_set.R +++ b/tests/testthat/helper-extract_parameter_set.R @@ -1,5 +1,8 @@ check_parameter_set_tibble <- function(x) { - expect_equal(names(x), c("name", "id", "source", "component", "component_id", "object")) + expect_equal( + names(x), + c("name", "id", "source", "component", "component_id", "object") + ) expect_equal(class(x$name), "character") expect_equal(class(x$id), "character") expect_equal(class(x$source), "character") @@ -8,7 +11,9 @@ check_parameter_set_tibble <- function(x) { expect_true(!any(duplicated(x$id))) expect_equal(class(x$object), "list") - obj_check <- purrr::map_lgl(x$object, \(x) inherits(x, "param") | all(is.na(x))) + obj_check <- purrr::map_lgl(x$object, \(x) { + inherits(x, "param") | all(is.na(x)) + 
}) expect_true(all(obj_check)) invisible(TRUE) diff --git a/tests/testthat/test-args_and_modes.R b/tests/testthat/test-args_and_modes.R index 2fc2d55bb..3295991ad 100644 --- a/tests/testthat/test-args_and_modes.R +++ b/tests/testthat/test-args_and_modes.R @@ -25,7 +25,6 @@ test_that('pipe arguments', { ) expect_snapshot(error = TRUE, rand_forest() |> set_args()) - }) @@ -47,7 +46,6 @@ test_that("can't set a mode that isn't allowed by the model spec", { }) - test_that("unavailable modes for an engine and vice-versa", { expect_snapshot( decision_tree() |> @@ -76,7 +74,7 @@ test_that("unavailable modes for an engine and vice-versa", { ) expect_snapshot( - decision_tree(engine = NULL)|> + decision_tree(engine = NULL) |> set_mode("regression") |> set_engine("C5.0"), error = TRUE @@ -109,35 +107,20 @@ test_that("unavailable modes for an engine and vice-versa", { }) test_that("set_* functions error when input isn't model_spec", { - expect_snapshot(error = TRUE, - set_mode(mtcars, "regression") - ) + expect_snapshot(error = TRUE, set_mode(mtcars, "regression")) - expect_snapshot(error = TRUE, - set_args(mtcars, blah = "blah") - ) + expect_snapshot(error = TRUE, set_args(mtcars, blah = "blah")) - expect_snapshot(error = TRUE, - bag_tree |> set_mode("classification") - ) + expect_snapshot(error = TRUE, bag_tree |> set_mode("classification")) - expect_snapshot(error = TRUE, - bag_tree |> set_engine("rpart") - ) + expect_snapshot(error = TRUE, bag_tree |> set_engine("rpart")) - expect_snapshot(error = TRUE, - bag_tree |> set_args(boop = "bop") - ) + expect_snapshot(error = TRUE, bag_tree |> set_args(boop = "bop")) # won't raise "info" part of error if not a parsnip-namespaced function # not a function - expect_snapshot(error = TRUE, - 1L |> set_args(mode = "classification") - ) + expect_snapshot(error = TRUE, 1L |> set_args(mode = "classification")) # not from parsnip - expect_snapshot(error = TRUE, - bag_tree |> set_mode("classification") - ) + expect_snapshot(error = TRUE, bag_tree |> set_mode("classification")) }) - diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index 19b6c536c..a9d75feb7 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -6,31 +6,74 @@ test_that('regression models', { expect_equal( colnames(augment(reg_form, head(mtcars))), - c( ".pred", ".resid", - "mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb") + c( + ".pred", + ".resid", + "mpg", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb" + ) ) expect_equal(nrow(augment(reg_form, head(mtcars))), 6) expect_equal( colnames(augment(reg_form, head(mtcars[, -1]))), - c(".pred", - "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb") + c( + ".pred", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb" + ) ) expect_equal(nrow(augment(reg_form, head(mtcars[, -1]))), 6) expect_equal( colnames(augment(reg_xy, head(mtcars))), - c(".pred", - "mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb") + c( + ".pred", + "mpg", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb" + ) ) expect_equal(nrow(augment(reg_xy, head(mtcars))), 6) expect_equal( colnames(augment(reg_xy, head(mtcars[, -1]))), - c(".pred", - "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb") + c( + ".pred", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb" + ) ) 
expect_equal(nrow(augment(reg_xy, head(mtcars[, -1]))), 6) @@ -42,11 +85,9 @@ test_that('regression models', { error = TRUE, augment(reg_form, head(mtcars[, -1])) ) - }) - test_that('classification models', { skip_if_not_installed("modeldata") @@ -77,7 +118,6 @@ test_that('classification models', { c(".pred_class", ".pred_Class1", ".pred_Class2", "A", "B") ) expect_equal(nrow(augment(cls_xy, head(two_class_dat[, -3]))), 6) - }) @@ -94,12 +134,11 @@ test_that('augment for model without class probabilities', { c(".pred_class", "A", "B", "Class") ) expect_equal(nrow(augment(cls_form, head(two_class_dat))), 6) - }) test_that('quantile regression models', { - probs_1 <- (1:5)/5 + probs_1 <- (1:5) / 5 expect_snapshot( linear_reg() |> set_mode("quantile regression", quantile_levels = probs_1) diff --git a/tests/testthat/test-bag_mars.R b/tests/testthat/test-bag_mars.R index d1fdcf5b3..dca351bc0 100644 --- a/tests/testthat/test-bag_mars.R +++ b/tests/testthat/test-bag_mars.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/baguette expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-bag_mlp.R b/tests/testthat/test-bag_mlp.R index d1fdcf5b3..dca351bc0 100644 --- a/tests/testthat/test-bag_mlp.R +++ b/tests/testthat/test-bag_mlp.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/baguette expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-bag_tree.R b/tests/testthat/test-bag_tree.R index d1fdcf5b3..dca351bc0 100644 --- a/tests/testthat/test-bag_tree.R +++ b/tests/testthat/test-bag_tree.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/baguette expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-boost_tree_C5.0.R b/tests/testthat/test-boost_tree_C5.0.R index 561f99d2a..341a88866 100644 --- a/tests/testthat/test-boost_tree_C5.0.R +++ b/tests/testthat/test-boost_tree_C5.0.R @@ -6,13 +6,12 @@ lending_club_fail <- dplyr::mutate(bad = Inf, miss = NA) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") lc_basic <- - boost_tree(mode = "classification") |> - set_engine("C5.0", bands = 2) + boost_tree(mode = "classification") |> + set_engine("C5.0", bands = 2) # ------------------------------------------------------------------------------ test_that('C5.0 execution', { - skip_if_not_installed("C50") expect_no_condition( @@ -68,7 +67,6 @@ test_that('C5.0 execution', { }) test_that('C5.0 prediction', { - skip_if_not_installed("C50") classes_xy <- fit_xy( @@ -78,13 +76,17 @@ test_that('C5.0 prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(classes_xy), newdata = lending_club[1:7, num_pred]) - expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, num_pred])$.pred_class) - + xy_pred <- predict( + extract_fit_engine(classes_xy), + newdata = lending_club[1:7, num_pred] + ) + expect_equal( + xy_pred, + predict(classes_xy, lending_club[1:7, num_pred])$.pred_class + ) }) test_that('C5.0 probabilities', { - skip_if_not_installed("C50") classes_xy <- fit_xy( @@ -94,36 +96,63 @@ test_that('C5.0 probabilities', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(classes_xy), newdata = as.data.frame(lending_club[1:7, num_pred]), type = "prob") + xy_pred <- predict( + extract_fit_engine(classes_xy), + newdata = as.data.frame(lending_club[1:7, num_pred]), + type = "prob" + ) xy_pred <- as_tibble(xy_pred) names(xy_pred) <- c(".pred_bad", ".pred_good") - expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, 
num_pred], type = "prob")) + expect_equal( + xy_pred, + predict(classes_xy, lending_club[1:7, num_pred], type = "prob") + ) one_row <- predict(classes_xy, lending_club[1, num_pred], type = "prob") - expect_equal(xy_pred[1,], one_row) - + expect_equal(xy_pred[1, ], one_row) }) test_that('submodel prediction', { - skip_if_not_installed("C50") library(C50) - vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") + vars <- c( + "female", + "tenure", + "total_charges", + "phone_service", + "monthly_charges" + ) class_fit <- boost_tree(trees = 20, mode = "classification") |> set_engine("C5.0", control = C5.0Control(earlyStopping = FALSE)) |> fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)]) - pred_class <- predict(extract_fit_engine(class_fit), wa_churn[1:4, vars], trials = 4, type = "prob") + pred_class <- predict( + extract_fit_engine(class_fit), + wa_churn[1:4, vars], + trials = 4, + type = "prob" + ) - mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 4, type = "prob") + mp_res <- multi_predict( + class_fit, + new_data = wa_churn[1:4, vars], + trees = 4, + type = "prob" + ) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_No"]], pred_class[, "No"], ignore_attr = TRUE) - expect_snapshot(error = TRUE, - multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 4, type = "prob") + expect_snapshot( + error = TRUE, + multi_predict( + class_fit, + newdata = wa_churn[1:4, vars], + trees = 4, + type = "prob" + ) ) }) @@ -143,13 +172,12 @@ test_that('argument checks for data dimensions', { set_mode("classification") expect_snapshot( - f_fit <- spec |> fit(species ~ ., data = penguins) + f_fit <- spec |> fit(species ~ ., data = penguins) ) expect_snapshot( xy_fit <- spec |> fit_xy(x = penguins[, -1], y = penguins$species) ) - expect_equal(extract_fit_engine(f_fit)$control$minCases, nrow(penguins)) + expect_equal(extract_fit_engine(f_fit)$control$minCases, nrow(penguins)) expect_equal(extract_fit_engine(xy_fit)$control$minCases, nrow(penguins)) - }) diff --git a/tests/testthat/test-boost_tree_xgboost.R b/tests/testthat/test-boost_tree_xgboost.R index eba427009..0841f1faa 100644 --- a/tests/testthat/test-boost_tree_xgboost.R +++ b/tests/testthat/test-boost_tree_xgboost.R @@ -59,10 +59,12 @@ test_that('xgboost execution, classification', { ) }) - expect_equal(res_f$fit$evaluation_log, res_xy$fit$evaluation_log) + expect_equal(res_f$fit$evaluation_log, res_xy$fit$evaluation_log) expect_equal(res_f_wts$fit$evaluation_log, res_xy_wts$fit$evaluation_log) # Check to see if the case weights had an effect - expect_true(!isTRUE(all.equal(res_f$fit$evaluation_log, res_f_wts$fit$evaluation_log))) + expect_true( + !isTRUE(all.equal(res_f$fit$evaluation_log, res_f_wts$fit$evaluation_log)) + ) expect_true(has_multi_predict(res_xy)) expect_equal(multi_predict_args(res_xy), "trees") @@ -80,7 +82,6 @@ test_that('xgboost execution, classification', { test_that('xgboost classification prediction', { - skip_if_not_installed("xgboost") skip_on_cran() @@ -95,10 +96,20 @@ test_that('xgboost classification prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), newdata = xgb.DMatrix(data = as.matrix(hpc[1:8, num_pred])), type = "class") + xy_pred <- predict( + extract_fit_engine(xy_fit), + newdata = xgb.DMatrix(data = as.matrix(hpc[1:8, num_pred])), + type = "class" + ) xy_pred <- matrix(xy_pred, ncol = 4, byrow = TRUE) - xy_pred <- factor(levels(hpc$class)[apply(xy_pred, 1, which.max)], levels = levels(hpc$class)) - 
expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")$.pred_class) + xy_pred <- factor( + levels(hpc$class)[apply(xy_pred, 1, which.max)], + levels = levels(hpc$class) + ) + expect_equal( + xy_pred, + predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")$.pred_class + ) form_fit <- fit( hpc_xgboost, @@ -107,10 +118,20 @@ test_that('xgboost classification prediction', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), newdata = xgb.DMatrix(data = as.matrix(hpc[1:8, num_pred])), type = "class") + form_pred <- predict( + extract_fit_engine(form_fit), + newdata = xgb.DMatrix(data = as.matrix(hpc[1:8, num_pred])), + type = "class" + ) form_pred <- matrix(form_pred, ncol = 4, byrow = TRUE) - form_pred <- factor(levels(hpc$class)[apply(form_pred, 1, which.max)], levels = levels(hpc$class)) - expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred], type = "class")$.pred_class) + form_pred <- factor( + levels(hpc$class)[apply(form_pred, 1, which.max)], + levels = levels(hpc$class) + ) + expect_equal( + form_pred, + predict(form_fit, new_data = hpc[1:8, num_pred], type = "class")$.pred_class + ) }) @@ -131,7 +152,6 @@ bad_rf_reg <- set_engine("xgboost", sampsize = -10) test_that('xgboost execution, regression', { - skip_if_not_installed("xgboost") skip_on_cran() @@ -154,13 +174,10 @@ test_that('xgboost execution, regression', { control = ctrl ) ) - }) - test_that('xgboost regression prediction', { - skip_if_not_installed("xgboost") skip_on_cran() @@ -173,7 +190,10 @@ test_that('xgboost regression prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) + xy_pred <- predict( + extract_fit_engine(xy_fit), + newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])) + ) expect_equal(xy_pred, predict(xy_fit, new_data = mtcars[1:8, -1])$.pred) form_fit <- fit( @@ -183,15 +203,19 @@ test_that('xgboost regression prediction', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) + form_pred <- predict( + extract_fit_engine(form_fit), + newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])) + ) expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred) - expect_equal(extract_fit_engine(form_fit)$params$objective, "reg:squarederror") - + expect_equal( + extract_fit_engine(form_fit)$params$objective, + "reg:squarederror" + ) }) - test_that('xgboost alternate objective', { skip_if_not_installed("xgboost") skip_on_cran() @@ -204,8 +228,11 @@ test_that('xgboost alternate objective', { set_mode("regression") xgb_fit <- spec |> fit(mpg ~ ., data = mtcars) - expect_equal(extract_fit_engine(xgb_fit)$params$objective, "reg:pseudohubererror") - expect_no_error(xgb_preds <- predict(xgb_fit, new_data = mtcars[1,])) + expect_equal( + extract_fit_engine(xgb_fit)$params$objective, + "reg:pseudohubererror" + ) + expect_no_error(xgb_preds <- predict(xgb_fit, new_data = mtcars[1, ])) expect_s3_class(xgb_preds, "data.frame") logregobj <- function(preds, dtrain) { @@ -223,12 +250,11 @@ test_that('xgboost alternate objective', { xgb_fit2 <- spec2 |> fit(vs ~ ., data = mtcars |> mutate(vs = as.factor(vs))) expect_equal(rlang::eval_tidy(xgb_fit2$spec$eng_args$objective), logregobj) - expect_no_error(xgb_preds2 <- predict(xgb_fit2, new_data = mtcars[1,-8])) + expect_no_error(xgb_preds2 <- predict(xgb_fit2, new_data = mtcars[1, -8])) expect_s3_class(xgb_preds2, "data.frame") }) 
test_that('submodel prediction', { - skip_if_not_installed("xgboost") skip_on_cran() @@ -239,31 +265,55 @@ test_that('submodel prediction', { set_engine("xgboost") |> fit(mpg ~ ., data = mtcars[-(1:4), ]) - x <- xgboost::xgb.DMatrix(as.matrix(mtcars[1:4, -1])) + x <- xgboost::xgb.DMatrix(as.matrix(mtcars[1:4, -1])) - pruned_pred <- predict(extract_fit_engine(reg_fit), x, iterationrange = c(1, 6)) + pruned_pred <- predict( + extract_fit_engine(reg_fit), + x, + iterationrange = c(1, 6) + ) mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], trees = 5) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred"]], pruned_pred) - - vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") + vars <- c( + "female", + "tenure", + "total_charges", + "phone_service", + "monthly_charges" + ) class_fit <- boost_tree(trees = 20, mode = "classification") |> set_engine("xgboost") |> fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)], control = ctrl) - x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars])) + x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars])) - pred_class <- predict(extract_fit_engine(class_fit), x, iterationrange = c(1, 6)) + pred_class <- predict( + extract_fit_engine(class_fit), + x, + iterationrange = c(1, 6) + ) - mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 5, type = "prob") + mp_res <- multi_predict( + class_fit, + new_data = wa_churn[1:4, vars], + trees = 5, + type = "prob" + ) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_Yes"]], pred_class) - expect_snapshot(error = TRUE, - multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 5, type = "prob") + expect_snapshot( + error = TRUE, + multi_predict( + class_fit, + newdata = wa_churn[1:4, vars], + trees = 5, + type = "prob" + ) ) }) @@ -282,7 +332,10 @@ test_that('validation sets', { fit(mpg ~ ., data = mtcars[-(1:4), ]) ) - expect_equal(colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], "validation_rmse") + expect_equal( + colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], + "validation_rmse" + ) expect_no_condition( reg_fit <- @@ -291,7 +344,10 @@ test_that('validation sets', { fit(mpg ~ ., data = mtcars[-(1:4), ]) ) - expect_equal(colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], "validation_mae") + expect_equal( + colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], + "validation_mae" + ) expect_no_condition( reg_fit <- @@ -300,7 +356,10 @@ test_that('validation sets', { fit(mpg ~ ., data = mtcars[-(1:4), ]) ) - expect_equal(colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], "training_mae") + expect_equal( + colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], + "training_mae" + ) expect_snapshot( error = TRUE, @@ -309,7 +368,6 @@ test_that('validation sets', { set_engine("xgboost", validation = 3) |> fit(mpg ~ ., data = mtcars[-(1:4), ]) ) - }) @@ -329,7 +387,11 @@ test_that('early stopping', { fit(mpg ~ ., data = mtcars[-(1:4), ]) ) - expect_equal(extract_fit_engine(reg_fit)$niter - extract_fit_engine(reg_fit)$best_iteration, 5) + expect_equal( + extract_fit_engine(reg_fit)$niter - + extract_fit_engine(reg_fit)$best_iteration, + 5 + ) expect_true(extract_fit_engine(reg_fit)$niter < 200) expect_no_condition( @@ -374,34 +436,56 @@ test_that('xgboost data conversion', { expect_true(inherits(from_mat$data, "xgb.DMatrix")) expect_true(inherits(from_mat$watchlist$training, "xgb.DMatrix")) - expect_no_condition(from_sparse <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg)) + 
expect_no_condition( + from_sparse <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg) + ) expect_true(inherits(from_mat$data, "xgb.DMatrix")) expect_true(inherits(from_mat$watchlist$training, "xgb.DMatrix")) - expect_no_condition(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, validation = .1)) + expect_no_condition( + from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, validation = .1) + ) expect_true(inherits(from_df$data, "xgb.DMatrix")) expect_true(inherits(from_df$watchlist$validation, "xgb.DMatrix")) expect_true(nrow(from_df$data) > nrow(from_df$watchlist$validation)) - expect_no_condition(from_mat <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, validation = .1)) + expect_no_condition( + from_mat <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, validation = .1) + ) expect_true(inherits(from_mat$data, "xgb.DMatrix")) expect_true(inherits(from_mat$watchlist$validation, "xgb.DMatrix")) expect_true(nrow(from_mat$data) > nrow(from_mat$watchlist$validation)) - expect_no_condition(from_sparse <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, validation = .1)) + expect_no_condition( + from_sparse <- parsnip:::as_xgb_data( + mtcar_smat, + mtcars$mpg, + validation = .1 + ) + ) expect_true(inherits(from_mat$data, "xgb.DMatrix")) expect_true(inherits(from_mat$watchlist$validation, "xgb.DMatrix")) expect_true(nrow(from_sparse$data) > nrow(from_sparse$watchlist$validation)) # set event_level for factors - mtcars_y <- factor(mtcars$mpg < 15, levels = c(TRUE, FALSE), labels = c("low", "high")) + mtcars_y <- factor( + mtcars$mpg < 15, + levels = c(TRUE, FALSE), + labels = c("low", "high") + ) expect_no_condition(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars_y)) - expect_equal(xgboost::getinfo(from_df$data, name = "label")[1:5], rep(0, 5)) - expect_no_condition(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars_y, event_level = "second")) - expect_equal(xgboost::getinfo(from_df$data, name = "label")[1:5], rep(1, 5)) + expect_equal(xgboost::getinfo(from_df$data, name = "label")[1:5], rep(0, 5)) + expect_no_condition( + from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars_y, event_level = "second") + ) + expect_equal(xgboost::getinfo(from_df$data, name = "label")[1:5], rep(1, 5)) - mtcars_y <- factor(mtcars$mpg < 15, levels = c(TRUE, FALSE, "na"), labels = c("low", "high", "missing")) + mtcars_y <- factor( + mtcars$mpg < 15, + levels = c(TRUE, FALSE, "na"), + labels = c("low", "high", "missing") + ) expect_snapshot( from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars_y, event_level = "second") ) @@ -412,11 +496,15 @@ test_that('xgboost data conversion', { ) expect_equal(wts, xgboost::getinfo(wted$data, "weight")) expect_no_condition( - wted_val <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, weights = wts, validation = 1/4) + wted_val <- parsnip:::as_xgb_data( + mtcar_x, + mtcars$mpg, + weights = wts, + validation = 1 / 4 + ) ) expect_true(all(xgboost::getinfo(wted_val$data, "weight") %in% wts)) expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight")) - }) @@ -447,9 +535,16 @@ test_that('xgboost data and sparse matrices', { from_mat$fit$handle <- NULL from_sparse$fit$handle <- NULL - - expect_equal(extract_fit_engine(from_df), extract_fit_engine(from_mat), ignore_function_env = TRUE) - expect_equal(extract_fit_engine(from_df), extract_fit_engine(from_sparse), ignore_function_env = TRUE) + expect_equal( + extract_fit_engine(from_df), + extract_fit_engine(from_mat), + ignore_function_env = TRUE + ) + expect_equal( + extract_fit_engine(from_df), + extract_fit_engine(from_sparse), + 
ignore_function_env = TRUE + ) # case weights added expect_no_condition( @@ -457,11 +552,15 @@ test_that('xgboost data and sparse matrices', { ) expect_equal(wts, xgboost::getinfo(wted$data, "weight")) expect_no_condition( - wted_val <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = wts, validation = 1/4) + wted_val <- parsnip:::as_xgb_data( + mtcar_smat, + mtcars$mpg, + weights = wts, + validation = 1 / 4 + ) ) expect_true(all(xgboost::getinfo(wted_val$data, "weight") %in% wts)) expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight")) - }) @@ -486,20 +585,25 @@ test_that('argument checks for data dimensions', { penguins_dummy <- as.data.frame(penguins_dummy[, -1]) expect_snapshot( - f_fit <- spec |> fit(species ~ ., data = penguins, control = ctrl) + f_fit <- spec |> fit(species ~ ., data = penguins, control = ctrl) ) expect_snapshot( - xy_fit <- spec |> fit_xy(x = penguins_dummy, y = penguins$species, control = ctrl) + xy_fit <- spec |> + fit_xy(x = penguins_dummy, y = penguins$species, control = ctrl) ) expect_equal(extract_fit_engine(f_fit)$params$colsample_bynode, 1) - expect_equal(extract_fit_engine(f_fit)$params$min_child_weight, nrow(penguins)) + expect_equal( + extract_fit_engine(f_fit)$params$min_child_weight, + nrow(penguins) + ) expect_equal(extract_fit_engine(xy_fit)$params$colsample_bynode, 1) - expect_equal(extract_fit_engine(xy_fit)$params$min_child_weight, nrow(penguins)) - + expect_equal( + extract_fit_engine(xy_fit)$params$min_child_weight, + nrow(penguins) + ) }) test_that("fit and prediction with `event_level`", { - skip_if_not_installed("xgboost") skip_on_cran() skip_if_not_installed("modeldata") @@ -513,28 +617,35 @@ test_that("fit and prediction with `event_level`", { train_y_1 <- -as.numeric(penguins$sex[-(1:4)]) + 2 train_y_2 <- as.numeric(penguins$sex[-(1:4)]) - 1 - x_pred <- xgboost::xgb.DMatrix(as.matrix(penguins[1:4, -5])) + x_pred <- xgboost::xgb.DMatrix(as.matrix(penguins[1:4, -5])) # event_level = "first" set.seed(24) fit_p_1 <- boost_tree(trees = 10) |> - set_engine("xgboost", eval_metric = "auc" - # event_level = "first" is the default - ) |> + set_engine( + "xgboost", + eval_metric = "auc" + # event_level = "first" is the default + ) |> set_mode("classification") |> fit(sex ~ ., data = penguins[-(1:4), ]) xgbmat_train_1 <- xgb.DMatrix(data = train_x, label = train_y_1) set.seed(24) - fit_xgb_1 <- xgboost::xgb.train(data = xgbmat_train_1, - nrounds = 10, - watchlist = list("training" = xgbmat_train_1), - objective = "binary:logistic", - eval_metric = "auc", - verbose = 0) + fit_xgb_1 <- xgboost::xgb.train( + data = xgbmat_train_1, + nrounds = 10, + watchlist = list("training" = xgbmat_train_1), + objective = "binary:logistic", + eval_metric = "auc", + verbose = 0 + ) - expect_equal(extract_fit_engine(fit_p_1)$evaluation_log, fit_xgb_1$evaluation_log) + expect_equal( + extract_fit_engine(fit_p_1)$evaluation_log, + fit_xgb_1$evaluation_log + ) pred_xgb_1 <- predict(fit_xgb_1, x_pred) pred_p_1 <- predict(fit_p_1, new_data = penguins[1:4, ], type = "prob") @@ -543,27 +654,30 @@ test_that("fit and prediction with `event_level`", { # event_level = "second" set.seed(24) fit_p_2 <- boost_tree(trees = 10) |> - set_engine("xgboost", eval_metric = "auc", - event_level = "second") |> + set_engine("xgboost", eval_metric = "auc", event_level = "second") |> set_mode("classification") |> fit(sex ~ ., data = penguins[-(1:4), ]) xgbmat_train_2 <- xgb.DMatrix(data = train_x, label = train_y_2) set.seed(24) - fit_xgb_2 <- xgboost::xgb.train(data = 
xgbmat_train_2, - nrounds = 10, - watchlist = list("training" = xgbmat_train_2), - objective = "binary:logistic", - eval_metric = "auc", - verbose = 0) + fit_xgb_2 <- xgboost::xgb.train( + data = xgbmat_train_2, + nrounds = 10, + watchlist = list("training" = xgbmat_train_2), + objective = "binary:logistic", + eval_metric = "auc", + verbose = 0 + ) - expect_equal(extract_fit_engine(fit_p_2)$evaluation_log, fit_xgb_2$evaluation_log) + expect_equal( + extract_fit_engine(fit_p_2)$evaluation_log, + fit_xgb_2$evaluation_log + ) pred_xgb_2 <- predict(fit_xgb_2, x_pred) pred_p_2 <- predict(fit_p_2, new_data = penguins[1:4, ], type = "prob") expect_equal(pred_p_2[[".pred_male"]], pred_xgb_2) - }) test_that("count/proportion parameters", { @@ -578,15 +692,24 @@ test_that("count/proportion parameters", { set_mode("regression") |> fit(mpg ~ ., data = mtcars) expect_equal(extract_fit_engine(fit1)$params$colsample_bytree, 1) - expect_equal(extract_fit_engine(fit1)$params$colsample_bynode, 7/(ncol(mtcars) - 1)) + expect_equal( + extract_fit_engine(fit1)$params$colsample_bynode, + 7 / (ncol(mtcars) - 1) + ) fit2 <- boost_tree(mtry = 7, trees = 4) |> set_engine("xgboost", colsample_bytree = 4) |> set_mode("regression") |> fit(mpg ~ ., data = mtcars) - expect_equal(extract_fit_engine(fit2)$params$colsample_bytree, 4/(ncol(mtcars) - 1)) - expect_equal(extract_fit_engine(fit2)$params$colsample_bynode, 7/(ncol(mtcars) - 1)) + expect_equal( + extract_fit_engine(fit2)$params$colsample_bytree, + 4 / (ncol(mtcars) - 1) + ) + expect_equal( + extract_fit_engine(fit2)$params$colsample_bynode, + 7 / (ncol(mtcars) - 1) + ) fit3 <- boost_tree(trees = 4) |> @@ -611,7 +734,6 @@ test_that("count/proportion parameters", { set_mode("regression") |> fit(mpg ~ ., data = mtcars) ) - }) test_that('interface to param arguments', { @@ -659,7 +781,10 @@ test_that('interface to param arguments', { class = "xgboost_params_warning" ) - expect_equal(extract_fit_engine(fit_3)$params$objective, "reg:pseudohubererror") + expect_equal( + extract_fit_engine(fit_3)$params$objective, + "reg:pseudohubererror" + ) # pass objective as main argument (good) spec_4 <- @@ -670,7 +795,10 @@ test_that('interface to param arguments', { fit_4 <- spec_4 |> fit(mpg ~ ., data = mtcars) ) - expect_equal(extract_fit_engine(fit_4)$params$objective, "reg:pseudohubererror") + expect_equal( + extract_fit_engine(fit_4)$params$objective, + "reg:pseudohubererror" + ) # pass a guarded argument as a main argument (bad) spec_5 <- diff --git a/tests/testthat/test-c5_rules.R b/tests/testthat/test-c5_rules.R index 0b3c40e3c..2cadaf221 100644 --- a/tests/testthat/test-c5_rules.R +++ b/tests/testthat/test-c5_rules.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/rules expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-case-weights.R b/tests/testthat/test-case-weights.R index dcbcd35a4..0d4653b9d 100644 --- a/tests/testthat/test-case-weights.R +++ b/tests/testthat/test-case-weights.R @@ -1,12 +1,10 @@ - test_that('case weights with xy method', { - skip_if_not_installed("C50") skip_if_not_installed("modeldata") data("two_class_dat", package = "modeldata") wts <- runif(nrow(two_class_dat)) - wts <- ifelse(wts < 1/5, 0, 1) + wts <- ifelse(wts < 1 / 5, 0, 1) two_class_subset <- two_class_dat[wts != 0, ] wts <- importance_weights(wts) @@ -45,13 +43,12 @@ test_that('case weights with xy method', { test_that('case weights with xy method - non-standard argument names', { - skip_if_not_installed("ranger") 
skip_if_not_installed("modeldata") data("two_class_dat", package = "modeldata") wts <- runif(nrow(two_class_dat)) - wts <- ifelse(wts < 1/5, 0, 1) + wts <- ifelse(wts < 1 / 5, 0, 1) two_class_subset <- two_class_dat[wts != 0, ] wts <- importance_weights(wts) @@ -82,14 +79,13 @@ test_that('case weights with xy method - non-standard argument names', { }) test_that('case weights with formula method', { - skip_if_not_installed("modeldata") data("ames", package = "modeldata") ames$Sale_Price <- log10(ames$Sale_Price) set.seed(1) wts <- runif(nrow(ames)) - wts <- ifelse(wts < 1/5, 0L, 1L) + wts <- ifelse(wts < 1 / 5, 0L, 1L) ames_subset <- ames[wts != 0, ] wts <- frequency_weights(wts) @@ -107,14 +103,13 @@ test_that('case weights with formula method', { }) test_that('case weights with formula method -- unregistered model spec', { - skip_if_not_installed("modeldata") data("ames", package = "modeldata") ames$Sale_Price <- log10(ames$Sale_Price) set.seed(1) wts <- runif(nrow(ames)) - wts <- ifelse(wts < 1/5, 0L, 1L) + wts <- ifelse(wts < 1 / 5, 0L, 1L) ames_subset <- ames[wts != 0, ] wts <- frequency_weights(wts) @@ -126,14 +121,13 @@ test_that('case weights with formula method -- unregistered model spec', { }) test_that('case weights with formula method that goes through `fit_xy()`', { - skip_if_not_installed("modeldata") data("ames", package = "modeldata") ames$Sale_Price <- log10(ames$Sale_Price) set.seed(1) wts <- runif(nrow(ames)) - wts <- ifelse(wts < 1/5, 0L, 1L) + wts <- ifelse(wts < 1 / 5, 0L, 1L) ames_subset <- ames[wts != 0, ] wts <- frequency_weights(wts) @@ -144,7 +138,8 @@ test_that('case weights with formula method that goes through `fit_xy()`', { x = ames[c("Longitude", "Latitude")], y = ames$Sale_Price, case_weights = wts - )) + ) + ) lm_sub_fit <- linear_reg() |> diff --git a/tests/testthat/test-condense_control.R b/tests/testthat/test-condense_control.R index 345bac5dc..3f77a7912 100644 --- a/tests/testthat/test-condense_control.R +++ b/tests/testthat/test-condense_control.R @@ -15,9 +15,7 @@ test_that("condense_control works", { ) ctrl$anotherone <- 2 - expect_snapshot(error = TRUE, - condense_control(control_parsnip(), ctrl) - ) + expect_snapshot(error = TRUE, condense_control(control_parsnip(), ctrl)) # Emulate being called from one of the upstream control_* functions control_test <- function(control = control_parsnip()) { diff --git a/tests/testthat/test-convert_data.R b/tests/testthat/test-convert_data.R index 58d58f24a..64e664d7e 100644 --- a/tests/testthat/test-convert_data.R +++ b/tests/testthat/test-convert_data.R @@ -5,10 +5,11 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # to go from lm_object$x results to our format format_x_for_test <- function(x, df = TRUE) { x <- x[, colnames(x) != "(Intercept)", drop = FALSE] - if (df) + if (df) { as.data.frame(x) - else + } else { x + } } Puromycin_miss <- Puromycin @@ -35,16 +36,16 @@ test_that("numeric x and y", { expect_null(observed$offset) expect_equal( - mtcars[1:6,-1], + mtcars[1:6, -1], .convert_form_to_xy_new( observed, - new_data = head(mtcars))$x + new_data = head(mtcars) + )$x ) }) test_that("numeric x and y, subsetting", { - expected <- lm(mpg ~ ., data = mtcars, subset = hp > 170, - x = TRUE, y = TRUE) + expected <- lm(mpg ~ ., data = mtcars, subset = hp > 170, x = TRUE, y = TRUE) observed <- .convert_form_to_xy_fit( mpg ~ ., @@ -63,8 +64,13 @@ test_that("numeric x and y, subsetting", { }) test_that("numeric x and y, weights", { - expected <- lm(mpg ~ . 
-disp, data = mtcars, weights = disp, - x = TRUE, y = TRUE) + expected <- lm( + mpg ~ . - disp, + data = mtcars, + weights = disp, + x = TRUE, + y = TRUE + ) observed <- .convert_form_to_xy_fit( mpg ~ . - disp, @@ -82,13 +88,15 @@ test_that("numeric x and y, weights", { }) test_that("numeric x and y, offset in-line", { - expected <- lm(mpg ~ cyl + hp + offset(log(disp)), - data = mtcars, - x = TRUE, - y = TRUE) + expected <- lm( + mpg ~ cyl + hp + offset(log(disp)), + data = mtcars, + x = TRUE, + y = TRUE + ) observed <- .convert_form_to_xy_fit( - mpg ~ cyl + hp + offset(log(disp)), + mpg ~ cyl + hp + offset(log(disp)), data = mtcars, indicators = "traditional", remove_intercept = TRUE @@ -109,15 +117,14 @@ test_that("numeric x and y, offset in-line", { test_that("numeric x and y, multiple offsets in-line", { expected <- lm( - mpg ~ cyl + hp + offset(log(disp)) + offset(qsec), + mpg ~ cyl + hp + offset(log(disp)) + offset(qsec), data = mtcars, x = TRUE, y = TRUE ) observed <- .convert_form_to_xy_fit( - mpg ~ cyl + hp + offset(log(disp)) + - offset(qsec), + mpg ~ cyl + hp + offset(log(disp)) + offset(qsec), data = mtcars, indicators = "traditional", remove_intercept = TRUE @@ -132,15 +139,11 @@ test_that("numeric x and y, multiple offsets in-line", { new_obs <- .convert_form_to_xy_new(observed, new_data = mtcars[1:6, ]) expect_equal(mtcars[1:6, c("cyl", "hp")], new_obs$x) - expect_equal(log(mtcars$disp)[1:6] + mtcars$qsec[1:6], - new_obs$offset) + expect_equal(log(mtcars$disp)[1:6] + mtcars$qsec[1:6], new_obs$offset) }) test_that("numeric x and y, no intercept", { - expected <- lm(mpg ~ 0 + ., - data = mtcars, - x = TRUE, - y = TRUE) + expected <- lm(mpg ~ 0 + ., data = mtcars, x = TRUE, y = TRUE) observed <- .convert_form_to_xy_fit( mpg ~ 0 + ., @@ -155,15 +158,14 @@ test_that("numeric x and y, no intercept", { expect_null(observed$offset) expect_null(observed$weights) - expect_equal(mtcars[1:6,-1], - .convert_form_to_xy_new(observed, new_data = head(mtcars))$x) + expect_equal( + mtcars[1:6, -1], + .convert_form_to_xy_new(observed, new_data = head(mtcars))$x + ) }) test_that("numeric x and y, inline functions", { - expected <- lm(log(mpg) ~ hp + poly(wt, 3), - data = mtcars, - x = TRUE, - y = TRUE) + expected <- lm(log(mpg) ~ hp + poly(wt, 3), data = mtcars, x = TRUE, y = TRUE) observed <- .convert_form_to_xy_fit( log(mpg) ~ hp + poly(wt, 3), @@ -237,8 +239,7 @@ test_that("numeric y and mixed x, omit missing data", { remove_intercept = TRUE ) expect_equal(format_x_for_test(expected$x), observed$x) - expect_equal(Puromycin_miss$rate[complete.cases(Puromycin_miss)], - observed$y) + expect_equal(Puromycin_miss$rate[complete.cases(Puromycin_miss)], observed$y) expect_equal(expected$terms, observed$terms) expect_equal(expected$xlevels, observed$xlevels) expect_null(observed$weights) @@ -255,9 +256,7 @@ test_that("numeric y and mixed x, omit missing data", { }) test_that("numeric y and mixed x, include missing data", { - frame_obj <- model.frame(rate ~ ., - data = Puromycin_miss, - na.action = na.pass) + frame_obj <- model.frame(rate ~ ., data = Puromycin_miss, na.action = na.pass) expected <- model.matrix(rate ~ ., frame_obj) observed <- .convert_form_to_xy_fit( rate ~ ., @@ -292,7 +291,7 @@ test_that("numeric y and mixed x, fail missing data", { }) test_that("numeric y and mixed x, no dummies", { - expected <- model.frame(rate ~ ., data = Puromycin)[,-1] + expected <- model.frame(rate ~ ., data = Puromycin)[, -1] observed <- .convert_form_to_xy_fit( rate ~ ., @@ -309,10 +308,7 @@ 
test_that("numeric y and mixed x, no dummies", { }) test_that("numeric x and numeric multivariate y", { - expected <- lm(cbind(mpg, disp) ~ ., - data = mtcars, - x = TRUE, - y = TRUE) + expected <- lm(cbind(mpg, disp) ~ ., data = mtcars, x = TRUE, y = TRUE) observed <- .convert_form_to_xy_fit( cbind(mpg, disp) ~ ., @@ -327,8 +323,10 @@ test_that("numeric x and numeric multivariate y", { expect_null(observed$weights) expect_null(observed$offset) - expect_equal(mtcars[1:6,-c(1, 3)], - .convert_form_to_xy_new(observed, new_data = head(mtcars))$x) + expect_equal( + mtcars[1:6, -c(1, 3)], + .convert_form_to_xy_new(observed, new_data = head(mtcars))$x + ) }) test_that("numeric x and factor y", { @@ -350,7 +348,10 @@ test_that("numeric x and factor y", { ) expect_no_error( - observed2 <- .convert_form_to_xy_fit(class ~ ., data = hpc |> mutate(x = NA)) + observed2 <- .convert_form_to_xy_fit( + class ~ ., + data = hpc |> mutate(x = NA) + ) ) expect_equal(hpc$class[logical()], observed2$y) expect_s3_class(observed2$terms, "terms") @@ -363,7 +364,8 @@ test_that("bad args", { expect_snapshot( error = TRUE, .convert_form_to_xy_fit( - mpg ~ ., data = mtcars, + mpg ~ ., + data = mtcars, composition = "tibble", indicators = "traditional", remove_intercept = TRUE @@ -372,7 +374,8 @@ test_that("bad args", { expect_snapshot( error = TRUE, .convert_form_to_xy_fit( - mpg ~ ., data = mtcars, + mpg ~ ., + data = mtcars, weights = letters[1:nrow(mtcars)], indicators = "traditional", remove_intercept = TRUE @@ -394,21 +397,20 @@ test_that("numeric x and y, matrix composition", { expect_equal(mtcars$mpg, observed$y) new_obs <- - .convert_form_to_xy_new(observed, - new_data = head(mtcars), - composition = "matrix") - expect_equal(as.matrix(mtcars[1:6,-1]), new_obs$x) + .convert_form_to_xy_new( + observed, + new_data = head(mtcars), + composition = "matrix" + ) + expect_equal(as.matrix(mtcars[1:6, -1]), new_obs$x) }) test_that("numeric x and multivariate y, matrix composition", { expected <- - lm(cbind(mpg, cyl) ~ ., - data = mtcars, - x = TRUE, - y = TRUE) + lm(cbind(mpg, cyl) ~ ., data = mtcars, x = TRUE, y = TRUE) observed <- .convert_form_to_xy_fit( - cbind(mpg, cyl) ~ ., + cbind(mpg, cyl) ~ ., data = mtcars, composition = "matrix", indicators = "traditional", @@ -418,10 +420,12 @@ test_that("numeric x and multivariate y, matrix composition", { expect_equal(expected$y, observed$y) new_obs <- - .convert_form_to_xy_new(observed, - new_data = head(mtcars), - composition = "matrix") - expect_equal(as.matrix(mtcars[1:6,-(1:2)]), new_obs$x) + .convert_form_to_xy_new( + observed, + new_data = head(mtcars), + composition = "matrix" + ) + expect_equal(as.matrix(mtcars[1:6, -(1:2)]), new_obs$x) }) test_that("global `contrasts` option is respected", { @@ -437,14 +441,20 @@ test_that("global `contrasts` option is respected", { ) fit_data <- fit_result$x - expect_identical(names(fit_data), c("class1", "class2", "class3", "compounds")) + expect_identical( + names(fit_data), + c("class1", "class2", "class3", "compounds") + ) expect_true(all(fit_data$class1 %in% c(-1, 0, 1))) # Predict time predict_result <- .convert_form_to_xy_new(fit_result, hpc) predict_data <- predict_result$x - expect_identical(names(predict_data), c("class1", "class2", "class3", "compounds")) + expect_identical( + names(predict_data), + c("class1", "class2", "class3", "compounds") + ) expect_true(all(predict_data$class1 %in% c(-1, 0, 1))) }) @@ -461,15 +471,20 @@ test_that("data frame x, vector y", { expect_equal(names(mtcars)[-1], observed$x_var) 
expect_null(observed$weights) - expect_equal(mtcars[1:6, -1], - .convert_xy_to_form_new(observed, new_data = head(mtcars[,-1]))) + expect_equal( + mtcars[1:6, -1], + .convert_xy_to_form_new(observed, new_data = head(mtcars[, -1])) + ) }) test_that("matrix x, vector y", { observed <- - .convert_xy_to_form_fit(as.matrix(mtcars[,-1]), mtcars$mpg, - remove_intercept = TRUE) + .convert_xy_to_form_fit( + as.matrix(mtcars[, -1]), + mtcars$mpg, + remove_intercept = TRUE + ) expected <- mtcars[, c(2:11, 1)] names(expected)[11] <- "..y" expect_equal(expected, observed$data) @@ -478,7 +493,7 @@ test_that("matrix x, vector y", { expect_null(observed$weights) expect_equal( - mtcars[1:6,-1], + mtcars[1:6, -1], .convert_xy_to_form_new(observed, new_data = as.matrix(mtcars[1:6, -1])) ) }) @@ -486,8 +501,11 @@ test_that("matrix x, vector y", { test_that("data frame x, 1 col data frame y", { observed <- - .convert_xy_to_form_fit(mtcars[, -1], mtcars[, "mpg", drop = FALSE], - remove_intercept = TRUE) + .convert_xy_to_form_fit( + mtcars[, -1], + mtcars[, "mpg", drop = FALSE], + remove_intercept = TRUE + ) expected <- mtcars[, c(2:11, 1)] expect_equal(expected, observed$data) expect_equal(formula("mpg ~ ."), observed$formula, ignore_formula_env = TRUE) @@ -497,9 +515,11 @@ test_that("data frame x, 1 col data frame y", { test_that("matrix x, 1 col matrix y", { observed <- - .convert_xy_to_form_fit(as.matrix(mtcars[,-1]), - as.matrix(mtcars[, "mpg", drop = FALSE]), - remove_intercept = TRUE) + .convert_xy_to_form_fit( + as.matrix(mtcars[, -1]), + as.matrix(mtcars[, "mpg", drop = FALSE]), + remove_intercept = TRUE + ) expected <- mtcars[, c(2:11, 1)] expect_equal(expected, observed$data) expect_equal(formula("mpg ~ ."), observed$formula, ignore_formula_env = TRUE) @@ -509,9 +529,11 @@ test_that("matrix x, 1 col matrix y", { test_that("matrix x, 1 col data frame y", { observed <- - .convert_xy_to_form_fit(as.matrix(mtcars[,-1]), - mtcars[, "mpg", drop = FALSE], - remove_intercept = TRUE) + .convert_xy_to_form_fit( + as.matrix(mtcars[, -1]), + mtcars[, "mpg", drop = FALSE], + remove_intercept = TRUE + ) expected <- mtcars[, c(2:11, 1)] expect_equal(expected, observed$data) expect_equal(formula("mpg ~ ."), observed$formula, ignore_formula_env = TRUE) @@ -521,9 +543,11 @@ test_that("matrix x, 1 col data frame y", { test_that("data frame x, 1 col matrix y", { observed <- - .convert_xy_to_form_fit(mtcars[,-1], - as.matrix(mtcars[, "mpg", drop = FALSE]), - remove_intercept = TRUE) + .convert_xy_to_form_fit( + mtcars[, -1], + as.matrix(mtcars[, "mpg", drop = FALSE]), + remove_intercept = TRUE + ) expected <- mtcars[, c(2:11, 1)] expect_equal(expected, observed$data) expect_equal(formula("mpg ~ ."), observed$formula, ignore_formula_env = TRUE) @@ -533,35 +557,46 @@ test_that("data frame x, 1 col matrix y", { test_that("data frame x, 2 col data frame y", { observed <- - .convert_xy_to_form_fit(mtcars[,-(1:2)], mtcars[, 1:2], - remove_intercept = TRUE) + .convert_xy_to_form_fit( + mtcars[, -(1:2)], + mtcars[, 1:2], + remove_intercept = TRUE + ) expected <- mtcars[, c(3:11, 1:2)] expect_equal(expected, observed$data) - expect_equal(formula("cbind(mpg, cyl) ~ ."), - observed$formula, - ignore_formula_env = TRUE) + expect_equal( + formula("cbind(mpg, cyl) ~ ."), + observed$formula, + ignore_formula_env = TRUE + ) expect_equal(names(mtcars)[-(1:2)], observed$x_var) expect_null(observed$weights) }) test_that("matrix x, 2 col matrix y", { observed <- - .convert_xy_to_form_fit(as.matrix(mtcars[,-(1:2)]), - as.matrix(mtcars[, 
1:2]), - remove_intercept = TRUE) + .convert_xy_to_form_fit( + as.matrix(mtcars[, -(1:2)]), + as.matrix(mtcars[, 1:2]), + remove_intercept = TRUE + ) expected <- mtcars[, c(3:11, 1:2)] expect_equal(expected, observed$data) - expect_equal(formula("cbind(mpg, cyl) ~ ."), - observed$formula, - ignore_formula_env = TRUE) + expect_equal( + formula("cbind(mpg, cyl) ~ ."), + observed$formula, + ignore_formula_env = TRUE + ) expect_equal(names(mtcars)[-(1:2)], observed$x_var) expect_null(observed$weights) }) test_that("1 col data frame x, 1 col data frame y", { - observed <- .convert_xy_to_form_fit(mtcars[, 2, drop = FALSE], - mtcars[, 1, drop = FALSE], - remove_intercept = TRUE) + observed <- .convert_xy_to_form_fit( + mtcars[, 2, drop = FALSE], + mtcars[, 1, drop = FALSE], + remove_intercept = TRUE + ) expected <- mtcars[, 2:1] expect_equal(expected, observed$data) expect_equal(formula("mpg ~ ."), observed$formula, ignore_formula_env = TRUE) @@ -613,7 +648,11 @@ test_that("bad args", { ) expect_snapshot( error = TRUE, - .convert_xy_to_form_fit(mtcars[, 1:3], mtcars[, 2:5], remove_intercept = TRUE) + .convert_xy_to_form_fit( + mtcars[, 1:3], + mtcars[, 2:5], + remove_intercept = TRUE + ) ) }) @@ -623,11 +662,16 @@ test_that("convert to matrix", { skip_if_not_installed("modeldata") expect_true(inherits(parsnip::maybe_matrix(mtcars), "matrix")) - expect_true(inherits(parsnip::maybe_matrix(tibble::as_tibble(mtcars)), "matrix")) + expect_true(inherits( + parsnip::maybe_matrix(tibble::as_tibble(mtcars)), + "matrix" + )) expect_true(inherits(parsnip::maybe_matrix(as.matrix(mtcars)), "matrix")) expect_true( - inherits(parsnip::maybe_matrix(Matrix::Matrix(as.matrix(mtcars), sparse = TRUE)), - "dgCMatrix") + inherits( + parsnip::maybe_matrix(Matrix::Matrix(as.matrix(mtcars), sparse = TRUE)), + "dgCMatrix" + ) ) data(ames, package = "modeldata") diff --git a/tests/testthat/test-cubist_rules.R b/tests/testthat/test-cubist_rules.R index 0b3c40e3c..2cadaf221 100644 --- a/tests/testthat/test-cubist_rules.R +++ b/tests/testthat/test-cubist_rules.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/rules expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-decision_tree.R b/tests/testthat/test-decision_tree.R index 3d240c1b2..8fe4557f4 100644 --- a/tests/testthat/test-decision_tree.R +++ b/tests/testthat/test-decision_tree.R @@ -19,7 +19,7 @@ test_that('bad input', { fit(bt, class ~ ., hpc) }) expect_snapshot_error({ - bt <- decision_tree(min_n = 0) |> set_engine("rpart") + bt <- decision_tree(min_n = 0) |> set_engine("rpart") fit(bt, class ~ ., hpc) }) expect_snapshot( @@ -41,13 +41,13 @@ test_that('argument checks for data dimensions', { set_mode("regression") expect_snapshot( - f_fit <- spec |> fit(body_mass_g ~ ., data = penguins) + f_fit <- spec |> fit(body_mass_g ~ ., data = penguins) ) expect_snapshot( xy_fit <- spec |> fit_xy(x = penguins[, -6], y = penguins$body_mass_g) ) - expect_equal(extract_fit_engine(f_fit)$control$minsplit, nrow(penguins)) + expect_equal(extract_fit_engine(f_fit)$control$minsplit, nrow(penguins)) expect_equal(extract_fit_engine(xy_fit)$control$minsplit, nrow(penguins)) spec <- @@ -56,8 +56,7 @@ test_that('argument checks for data dimensions', { set_mode("regression") args <- translate(spec)$method$fit$args - expect_equal(args$min_instances_per_node, rlang::expr(min_rows(1000, x))) - + expect_equal(args$min_instances_per_node, rlang::expr(min_rows(1000, x))) }) test_that("check_args() works", { diff --git 
a/tests/testthat/test-descriptors.R b/tests/testthat/test-descriptors.R index 4d86875c5..498126174 100644 --- a/tests/testthat/test-descriptors.R +++ b/tests/testthat/test-descriptors.R @@ -5,15 +5,21 @@ hpc <- hpc_data[1:150, c(2:5, 8)] |> as.data.frame() # ------------------------------------------------------------------------------ template <- function(col, pred, ob, lev, fact, dat, x, y) { - lst <- list(.cols = col, .preds = pred, .obs = ob, - .lvls = lev, .facts = fact, .dat = dat, - .x = x, .y = y) + lst <- list( + .cols = col, + .preds = pred, + .obs = ob, + .lvls = lev, + .facts = fact, + .dat = dat, + .x = x, + .y = y + ) Filter(Negate(is.null), lst) } eval_descrs <- function(descrs, not = NULL) { - if (!is.null(not)) { for (descr in not) { descrs[[descr]] <- NULL @@ -30,7 +36,6 @@ class_tab <- table(hpc$class, dnn = NULL) # Should descriptors be created? test_that("requires_descrs", { - # embedded in a function fn <- function() { .cols() @@ -52,12 +57,24 @@ test_that("requires_descrs", { expect_true(parsnip:::requires_descrs(rand_forest(mtry = fn2()))) # descriptors in `eng_args` - expect_false(parsnip:::requires_descrs(rand_forest() |> set_engine("ranger", arrrg = 3))) - expect_false(parsnip:::requires_descrs(rand_forest() |> set_engine("ranger", arrrg = tune()))) - expect_true(parsnip:::requires_descrs(rand_forest() |> set_engine("ranger", arrrg = .obs()))) - expect_false(parsnip:::requires_descrs(rand_forest() |> set_engine("ranger", arrrg = expr(3)))) - expect_true(parsnip:::requires_descrs(rand_forest() |> set_engine("ranger", arrrg = fn()))) - expect_true(parsnip:::requires_descrs(rand_forest() |> set_engine("ranger", arrrg = fn2()))) + expect_false(parsnip:::requires_descrs( + rand_forest() |> set_engine("ranger", arrrg = 3) + )) + expect_false(parsnip:::requires_descrs( + rand_forest() |> set_engine("ranger", arrrg = tune()) + )) + expect_true(parsnip:::requires_descrs( + rand_forest() |> set_engine("ranger", arrrg = .obs()) + )) + expect_false(parsnip:::requires_descrs( + rand_forest() |> set_engine("ranger", arrrg = expr(3)) + )) + expect_true(parsnip:::requires_descrs( + rand_forest() |> set_engine("ranger", arrrg = fn()) + )) + expect_true(parsnip:::requires_descrs( + rand_forest() |> set_engine("ranger", arrrg = fn2()) + )) # mixed expect_true( @@ -80,18 +97,18 @@ test_that("requires_descrs", { test_that("numeric y and dummy vars", { expect_equal( - template(6, 4, 150, NA, 1, hpc, hpc[-2], hpc[,"input_fields"]), + template(6, 4, 150, NA, 1, hpc, hpc[-2], hpc[, "input_fields"]), eval_descrs(get_descr_form(input_fields ~ ., data = hpc)) ) expect_equal( - template(3, 1, 150, NA, 1, hpc, hpc["class"], hpc[,"input_fields"]), + template(3, 1, 150, NA, 1, hpc, hpc["class"], hpc[, "input_fields"]), eval_descrs(get_descr_form(input_fields ~ class, data = hpc)) ) }) test_that("numeric y and x", { expect_equal( - template(1, 1, 150, NA, 0, hpc, hpc["input_fields"], hpc[,"compounds"]), + template(1, 1, 150, NA, 0, hpc, hpc["input_fields"], hpc[, "compounds"]), eval_descrs(get_descr_form(compounds ~ input_fields, data = hpc)) ) expect_equal( @@ -99,7 +116,7 @@ test_that("numeric y and x", { log_sep <- hpc["input_fields"] log_sep[["input_fields"]] <- log(log_sep[["input_fields"]]) names(log_sep) <- "log(input_fields)" - template(1, 1, 150, NA, 0, hpc, log_sep, hpc[,"compounds"]) + template(1, 1, 150, NA, 0, hpc, log_sep, hpc[, "compounds"]) }, eval_descrs(get_descr_form(compounds ~ log(input_fields), data = hpc)) ) @@ -107,19 +124,19 @@ test_that("numeric y and x", { 
test_that("factor y", { expect_equal( - template(4, 4, 150, class_tab, 0, hpc, hpc[-5], hpc[,"class"]), + template(4, 4, 150, class_tab, 0, hpc, hpc[-5], hpc[, "class"]), eval_descrs(get_descr_form(class ~ ., data = hpc)) ) expect_equal( - template(1, 1, 150, class_tab, 0, hpc, hpc["compounds"], hpc[,"class"]), + template(1, 1, 150, class_tab, 0, hpc, hpc["compounds"], hpc[, "class"]), eval_descrs(get_descr_form(class ~ compounds, data = hpc)) ) }) test_that("factors all the way down", { - dat <- npk[,1:4] + dat <- npk[, 1:4] expect_equal( - template(7, 3, 24, table(npk$K, dnn = NULL), 3, dat, dat[-4], dat[,"K"]), + template(7, 3, 24, table(npk$K, dnn = NULL), 3, dat, dat[-4], dat[, "K"]), eval_descrs(get_descr_form(K ~ ., data = dat)) ) }) @@ -128,22 +145,26 @@ test_that("weird cases", { # So model.frame ignores - signs in a model formula so class is not removed # prior to model.matrix; otherwise this should have n_cols = 3 expect_equal( - template(3, 4, 150, NA, 1, hpc, hpc[-2], hpc[,"input_fields"]), + template(3, 4, 150, NA, 1, hpc, hpc[-2], hpc[, "input_fields"]), eval_descrs(get_descr_form(input_fields ~ . - class, data = hpc)) ) # Oy ve! Before going to model.matrix, model.frame produces a data frame # with one column and that column is a matrix (with the results from # `poly(input_fields, 3)` - x <- model.frame(~poly(input_fields, 3), hpc) - attributes(x) <- attributes(as.data.frame(x))[c("names", "class", "row.names")] + x <- model.frame(~ poly(input_fields, 3), hpc) + attributes(x) <- attributes(as.data.frame(x))[c( + "names", + "class", + "row.names" + )] expect_equal( - template(3, 1, 150, NA, 0, hpc, x, hpc[,"compounds"]), + template(3, 1, 150, NA, 0, hpc, x, hpc[, "compounds"]), eval_descrs(get_descr_form(compounds ~ poly(input_fields, 3), data = hpc)) ) expect_equal( - template(0, 0, 150, NA, 0, hpc, hpc[,numeric()], hpc[,"compounds"]), + template(0, 0, 150, NA, 0, hpc, hpc[, numeric()], hpc[, "compounds"]), eval_descrs(get_descr_form(compounds ~ 1, data = hpc)) ) }) @@ -160,14 +181,14 @@ test_that("numeric y and dummy vars", { eval_descrs(get_descr_xy(x = hpc[, 1:4], y = hpc$class)) ) - hpc2 <- hpc[,c(4,5,1,2)] + hpc2 <- hpc[, c(4, 5, 1, 2)] rownames(hpc2) <- rownames(hpc2) expect_equal( - template(2, 2, 150, NA, 1, hpc2, hpc[,4:5], hpc[,1:2]), + template(2, 2, 150, NA, 1, hpc2, hpc[, 4:5], hpc[, 1:2]), eval_descrs(get_descr_xy(x = hpc[, 4:5], y = hpc[, 1:2])) ) - hpc3 <- hpc2[,c("num_pending", "class", "compounds")] + hpc3 <- hpc2[, c("num_pending", "class", "compounds")] expect_equal( template(2, 2, 150, NA, 1, hpc3, hpc[, 4:5], hpc[, 1, drop = FALSE]), eval_descrs(get_descr_xy(x = hpc[, 4:5], y = hpc[, 1, drop = FALSE])) @@ -179,9 +200,10 @@ test_that("numeric y and dummy vars", { # Descriptor helpers test_that("can be temporarily overriden at evaluation time", { - scope_n_cols <- function() { - scoped_descrs(list(.cols = function() { 1 })) + scoped_descrs(list(.cols = function() { + 1 + })) .cols() } @@ -193,7 +215,6 @@ test_that("can be temporarily overriden at evaluation time", { # .cols() should now be reset to an error expect_snapshot(error = TRUE, .cols()) - }) @@ -213,5 +234,4 @@ test_that("system-level descriptor tests", { set_engine("xgboost") |> fit(mpg ~ ., data = mtcars) ) - }) diff --git a/tests/testthat/test-discrim_flexible.R b/tests/testthat/test-discrim_flexible.R index 85639dc94..d95d58a5e 100644 --- a/tests/testthat/test-discrim_flexible.R +++ b/tests/testthat/test-discrim_flexible.R @@ -3,4 +3,4 @@ test_that("testing", { # 
https://github.com/tidymodels/discrim expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-discrim_linear.R b/tests/testthat/test-discrim_linear.R index 85639dc94..d95d58a5e 100644 --- a/tests/testthat/test-discrim_linear.R +++ b/tests/testthat/test-discrim_linear.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/discrim expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-discrim_quad.R b/tests/testthat/test-discrim_quad.R index 85639dc94..d95d58a5e 100644 --- a/tests/testthat/test-discrim_quad.R +++ b/tests/testthat/test-discrim_quad.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/discrim expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-discrim_regularized.R b/tests/testthat/test-discrim_regularized.R index 85639dc94..d95d58a5e 100644 --- a/tests/testthat/test-discrim_regularized.R +++ b/tests/testthat/test-discrim_regularized.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/discrim expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R index 7ccda16d9..b0bc5bbe4 100644 --- a/tests/testthat/test-extract.R +++ b/tests/testthat/test-extract.R @@ -30,7 +30,10 @@ test_that('extract parameter set from model with main and engine parameters', { skip_on_covr() bst_model <- - boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) |> + boost_tree( + mode = "classification", + trees = hardhat::tune("funky name \n") + ) |> set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) c5_info <- extract_parameter_set_dials(bst_model) @@ -53,7 +56,10 @@ test_that('extract parameter set from model with no loaded implementation', { set_mode("regression") expect_snapshot(error = TRUE, extract_parameter_set_dials(bt_mod)) - expect_snapshot(error = TRUE, extract_parameter_dials(bt_mod, parameter = "min_n")) + expect_snapshot( + error = TRUE, + extract_parameter_dials(bt_mod, parameter = "min_n") + ) }) # ------------------------------------------------------------------------------ @@ -73,7 +79,10 @@ test_that('extract single parameter from model with main and engine parameters', skip_on_covr() bst_model <- - boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) |> + boost_tree( + mode = "classification", + trees = hardhat::tune("funky name \n") + ) |> set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) expect_equal( @@ -91,7 +100,7 @@ test_that("extract_parameter_dials doesn't error if namespaced args are used", { bst_model <- logistic_reg(mode = "classification", penalty = hardhat::tune()) |> - set_engine("glmnet", family = stats::gaussian("log")) + set_engine("glmnet", family = stats::gaussian("log")) expect_no_condition( extract_parameter_dials(bst_model, parameter = "penalty") diff --git a/tests/testthat/test-failed_models.R b/tests/testthat/test-failed_models.R index d1f3bc86c..c6011581a 100644 --- a/tests/testthat/test-failed_models.R +++ b/tests/testthat/test-failed_models.R @@ -31,12 +31,15 @@ test_that('numeric model', { expect_snapshot(num_res <- predict(lm_mod, hpc_bad[1:11, -1])) expect_equal(num_res, NULL) - expect_snapshot(ci_res <- predict(lm_mod, hpc_bad[1:11, -1], type = "conf_int")) + expect_snapshot( + ci_res <- predict(lm_mod, hpc_bad[1:11, -1], type = "conf_int") + ) expect_equal(ci_res, NULL) - expect_snapshot(pi_res <- predict(lm_mod, hpc_bad[1:11, -1], type = 
"pred_int")) + expect_snapshot( + pi_res <- predict(lm_mod, hpc_bad[1:11, -1], type = "pred_int") + ) expect_equal(pi_res, NULL) - }) # ------------------------------------------------------------------------------ @@ -45,24 +48,38 @@ test_that('classification model', { log_reg <- logistic_reg() |> set_engine("glm") |> - fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl) + fit( + Class ~ log(funded_amnt) + int_rate + big_num, + data = lending_club, + control = ctrl + ) expect_snapshot( cls_res <- - predict(log_reg, lending_club |> dplyr::slice(1:7) |> dplyr::select(-Class)) + predict( + log_reg, + lending_club |> dplyr::slice(1:7) |> dplyr::select(-Class) + ) ) expect_equal(cls_res, NULL) expect_snapshot( prb_res <- - predict(log_reg, lending_club |> dplyr::slice(1:7) |> dplyr::select(-Class), type = "prob") + predict( + log_reg, + lending_club |> dplyr::slice(1:7) |> dplyr::select(-Class), + type = "prob" + ) ) expect_equal(prb_res, NULL) expect_snapshot( ci_res <- - predict(log_reg, lending_club |> dplyr::slice(1:7) |> dplyr::select(-Class), type = "conf_int") + predict( + log_reg, + lending_club |> dplyr::slice(1:7) |> dplyr::select(-Class), + type = "conf_int" + ) ) expect_equal(ci_res, NULL) }) - diff --git a/tests/testthat/test-fit_interfaces.R b/tests/testthat/test-fit_interfaces.R index bba169e6a..440ab9a12 100644 --- a/tests/testthat/test-fit_interfaces.R +++ b/tests/testthat/test-fit_interfaces.R @@ -10,18 +10,25 @@ sprk <- 1:10 class(sprk) <- c(class(sprk), "tbl_spark") tester <- - function(object, formula = NULL, data = NULL, model) - parsnip:::check_interface(formula, data, match.call(expand.dots = TRUE), model) + function(object, formula = NULL, data = NULL, model) { + parsnip:::check_interface( + formula, + data, + match.call(expand.dots = TRUE), + model + ) + } tester_xy <- - function(object, x = NULL, y = NULL, model) + function(object, x = NULL, y = NULL, model) { parsnip:::check_xy_interface(x, y, match.call(expand.dots = TRUE), model) + } test_that('good args', { - expect_equal( tester(NULL, formula = f, data = hpc, model = rmod), "formula") + expect_equal(tester(NULL, formula = f, data = hpc, model = rmod), "formula") expect_equal(tester_xy(NULL, x = hpc, y = hpc, model = rmod), "data.frame") - expect_equal( tester(NULL, f, data = hpc, model = rmod), "formula") - expect_equal( tester(NULL, f, data = sprk, model = rmod), "formula") + expect_equal(tester(NULL, f, data = hpc, model = rmod), "formula") + expect_equal(tester(NULL, f, data = sprk, model = rmod), "formula") }) #test_that('unnamed args', { @@ -30,23 +37,25 @@ test_that('good args', { #}) # test_that('wrong args', { - expect_snapshot(error = TRUE, tester_xy(NULL, x = sprk, y = hpc, model = rmod)) - expect_snapshot(error = TRUE, tester(NULL, f, data = as.matrix(hpc[, 1:4]))) + expect_snapshot( + error = TRUE, + tester_xy(NULL, x = sprk, y = hpc, model = rmod) + ) + expect_snapshot(error = TRUE, tester(NULL, f, data = as.matrix(hpc[, 1:4]))) }) test_that('single column df for issue #129', { - expect_no_condition( lm1 <- linear_reg() |> set_engine("lm") |> - fit_xy(x = mtcars[, 2:4], y = mtcars[,1, drop = FALSE]) + fit_xy(x = mtcars[, 2:4], y = mtcars[, 1, drop = FALSE]) ) expect_no_condition( lm2 <- linear_reg() |> set_engine("lm") |> - fit_xy(x = mtcars[, 2:4], y = as.matrix(mtcars)[,1, drop = FALSE]) + fit_xy(x = mtcars[, 2:4], y = as.matrix(mtcars)[, 1, drop = FALSE]) ) lm3 <- linear_reg() |> @@ -66,36 +75,34 @@ test_that('unknown modes', { ) expect_snapshot( error = TRUE, 
- fit_xy(mars_spec, x = mtcars[, -1], y = mtcars[,1]) + fit_xy(mars_spec, x = mtcars[, -1], y = mtcars[, 1]) ) expect_snapshot( error = TRUE, - fit_xy(mars_spec, x = lending_club[,1:2], y = lending_club$Class) + fit_xy(mars_spec, x = lending_club[, 1:2], y = lending_club$Class) ) }) test_that("misspecified formula argument", { rec <- structure(list(), class = "recipe") - expect_snapshot(error = TRUE, - fit(linear_reg(), rec, mtcars) - ) - expect_snapshot(error = TRUE, - fit(linear_reg(), "boop", mtcars) - ) + expect_snapshot(error = TRUE, fit(linear_reg(), rec, mtcars)) + expect_snapshot(error = TRUE, fit(linear_reg(), "boop", mtcars)) }) test_that("elapsed time parsnip mods", { lm1 <- linear_reg() |> set_engine("lm") |> - fit_xy(x = mtcars[, 2:4], y = mtcars$mpg, - control = control_parsnip(verbosity = 2L)) + fit_xy( + x = mtcars[, 2:4], + y = mtcars$mpg, + control = control_parsnip(verbosity = 2L) + ) lm2 <- linear_reg() |> set_engine("lm") |> - fit(mpg ~ ., data = mtcars, - control = control_parsnip(verbosity = 2)) + fit(mpg ~ ., data = mtcars, control = control_parsnip(verbosity = 2)) expect_output(print(lm1), "Fit time:") expect_output(print(lm2), "Fit time:") @@ -114,11 +121,17 @@ test_that("elapsed time parsnip mods", { test_that('No loaded engines', { expect_no_condition( - linear_reg() |> fit(mpg ~., data = mtcars) + linear_reg() |> fit(mpg ~ ., data = mtcars) ) - expect_snapshot_error({cubist_rules() |> fit(mpg ~., data = mtcars)}) - expect_snapshot_error({poisson_reg() |> fit(mpg ~., data = mtcars)}) - expect_snapshot_error({cubist_rules(engine = "Cubist") |> fit(mpg ~., data = mtcars)}) + expect_snapshot_error({ + cubist_rules() |> fit(mpg ~ ., data = mtcars) + }) + expect_snapshot_error({ + poisson_reg() |> fit(mpg ~ ., data = mtcars) + }) + expect_snapshot_error({ + cubist_rules(engine = "Cubist") |> fit(mpg ~ ., data = mtcars) + }) }) test_that("fit_xy() can handle attributes on a data.frame outcome (#1060)", { @@ -126,8 +139,9 @@ test_that("fit_xy() can handle attributes on a data.frame outcome (#1060)", { x <- data.frame(x = 1:5) y <- c(2:5, 5) - expect_silent(res <- - fit_xy(lr, x = x, y = data.frame(y = structure(y, label = "hi"))) + expect_silent( + res <- + fit_xy(lr, x = x, y = data.frame(y = structure(y, label = "hi"))) ) expect_equal(res[["fit"]], fit_xy(lr, x, y)[["fit"]], ignore_attr = "label") }) @@ -163,17 +177,23 @@ test_that("overhead of parsnip interface is minimal (#1071)", { bm <- bench::mark( time_engine = lm(mpg ~ ., mtcars), time_parsnip_form = fit(linear_reg(), mpg ~ ., mtcars), - time_parsnip_xy = fit_xy(linear_reg(), mtcars[2:11], mtcars[1]), + time_parsnip_xy = fit_xy(linear_reg(), mtcars[2:11], mtcars[1]), relative = TRUE, check = FALSE ) expect_true( bm$median[2] < 3.5, - label = paste0("parsnip overhead factor (formula interface): ", round(bm$median[2], 4)) + label = paste0( + "parsnip overhead factor (formula interface): ", + round(bm$median[2], 4) + ) ) expect_true( bm$median[3] < 3.75, - label = paste0("parsnip overhead factor (xy interface): ", round(bm$median[3], 4)) + label = paste0( + "parsnip overhead factor (xy interface): ", + round(bm$median[3], 4) + ) ) }) diff --git a/tests/testthat/test-gen_additive_model.R b/tests/testthat/test-gen_additive_model.R index e4fad7f8e..b5ba52be8 100644 --- a/tests/testthat/test-gen_additive_model.R +++ b/tests/testthat/test-gen_additive_model.R @@ -37,9 +37,9 @@ test_that('regression', { mgcv_ci <- predict(mgcv_mod, head(mtcars), type = "link", se.fit = TRUE) expect_equal(f_ci[[".std_error"]], 
mgcv_ci$se.fit) lower <- - mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit + mgcv_ci$fit - + qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit expect_equal(f_ci[[".pred_lower"]], lower) - }) # ------------------------------------------------------------------------------ @@ -69,10 +69,12 @@ test_that('classification', { ) ) mgcv_mod <- - mgcv::gam(Class ~ s(A, k = 10) + B, - data = two_class_dat, - gamma = 1.5, - family = binomial) + mgcv::gam( + Class ~ s(A, k = 10) + B, + data = two_class_dat, + gamma = 1.5, + family = binomial + ) expect_equal(coef(mgcv_mod), coef(extract_fit_engine(f_res))) f_pred <- predict(f_res, head(two_class_dat), type = "prob") @@ -85,14 +87,24 @@ test_that('classification', { f_cls <- predict(f_res, head(two_class_dat), type = "class") expect_true(all(f_cls$.pred_class[mgcv_pred < 0.5] == "Class1")) - f_ci <- predict(f_res, head(two_class_dat), type = "conf_int", std_error = TRUE) - mgcv_ci <- predict(mgcv_mod, head(two_class_dat), type = "link", se.fit = TRUE) + f_ci <- predict( + f_res, + head(two_class_dat), + type = "conf_int", + std_error = TRUE + ) + mgcv_ci <- predict( + mgcv_mod, + head(two_class_dat), + type = "link", + se.fit = TRUE + ) expect_equal(f_ci[[".std_error"]], mgcv_ci$se.fit) lower <- - mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit + mgcv_ci$fit - + qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit lower <- binomial()$linkinv(lower) expect_equal(f_ci[[".pred_lower_Class2"]], lower) - }) test_that("check_args() works", { diff --git a/tests/testthat/test-glm_grouped.R b/tests/testthat/test-glm_grouped.R index dba418d6f..74cf31b8d 100644 --- a/tests/testthat/test-glm_grouped.R +++ b/tests/testthat/test-glm_grouped.R @@ -7,7 +7,11 @@ test_that('correct results for glm_grouped()', { ungrouped <- glm(Admit ~ Gender + Dept, data = ucb_long, family = binomial) expect_no_condition( - grouped <- glm_grouped(Admit ~ Gender + Dept, data = ucb_weighted, weights = ucb_weighted$Freq) + grouped <- glm_grouped( + Admit ~ Gender + Dept, + data = ucb_weighted, + weights = ucb_weighted$Freq + ) ) expect_equal(grouped$df.null, 11) diff --git a/tests/testthat/test-linear_reg.R b/tests/testthat/test-linear_reg.R index b0893f8ea..a4c58e2fe 100644 --- a/tests/testthat/test-linear_reg.R +++ b/tests/testthat/test-linear_reg.R @@ -27,7 +27,6 @@ hpc_basic <- linear_reg() |> set_engine("lm") # ------------------------------------------------------------------------------ test_that('lm execution', { - expect_no_condition( res <- fit( hpc_basic, @@ -108,7 +107,6 @@ test_that('lm execution', { }) test_that('glm execution', { - hpc_glm <- linear_reg() |> set_engine("glm") expect_no_condition( @@ -158,7 +156,6 @@ test_that('glm execution', { control = caught_ctrl ) ) - }) @@ -196,11 +193,10 @@ test_that('lm prediction', { control = ctrl ) - expect_equal(mv_pred, predict(res_mv, hpc[1:5,])) + expect_equal(mv_pred, predict(res_mv, hpc[1:5, ])) }) test_that('glm prediction', { - hpc_glm <- linear_reg() |> set_engine("glm") uni_lm <- glm(compounds ~ input_fields + num_pending + iterations, data = hpc) @@ -225,46 +221,67 @@ test_that('glm prediction', { ) expect_equal(inl_pred, predict(res_form, hpc[1:5, ])$.pred) - }) test_that('lm intervals', { - stats_lm <- lm(compounds ~ input_fields + iterations + num_pending, - data = hpc) - confidence_lm <- predict(stats_lm, newdata = hpc[1:5, ], - level = 0.93, interval = "confidence") - prediction_lm <- 
predict(stats_lm, newdata = hpc[1:5, ], - level = 0.93, interval = "prediction") + stats_lm <- lm( + compounds ~ input_fields + iterations + num_pending, + data = hpc + ) + confidence_lm <- predict( + stats_lm, + newdata = hpc[1:5, ], + level = 0.93, + interval = "confidence" + ) + prediction_lm <- predict( + stats_lm, + newdata = hpc[1:5, ], + level = 0.93, + interval = "prediction" + ) res_xy <- fit_xy( - linear_reg() |> set_engine("lm"), + linear_reg() |> set_engine("lm"), x = hpc[, num_pred], y = hpc$compounds, control = ctrl ) confidence_parsnip <- - predict(res_xy, - new_data = hpc[1:5,], - type = "conf_int", - level = 0.93) + predict(res_xy, new_data = hpc[1:5, ], type = "conf_int", level = 0.93) - expect_equal(confidence_parsnip$.pred_lower, confidence_lm[, "lwr"], ignore_attr = TRUE) - expect_equal(confidence_parsnip$.pred_upper, confidence_lm[, "upr"], ignore_attr = TRUE) + expect_equal( + confidence_parsnip$.pred_lower, + confidence_lm[, "lwr"], + ignore_attr = TRUE + ) + expect_equal( + confidence_parsnip$.pred_upper, + confidence_lm[, "upr"], + ignore_attr = TRUE + ) prediction_parsnip <- - predict(res_xy, - new_data = hpc[1:5,], - type = "pred_int", - level = 0.93) + predict(res_xy, new_data = hpc[1:5, ], type = "pred_int", level = 0.93) - expect_equal(prediction_parsnip$.pred_lower, prediction_lm[, "lwr"], ignore_attr = TRUE) - expect_equal(prediction_parsnip$.pred_upper, prediction_lm[, "upr"], ignore_attr = TRUE) + expect_equal( + prediction_parsnip$.pred_lower, + prediction_lm[, "lwr"], + ignore_attr = TRUE + ) + expect_equal( + prediction_parsnip$.pred_upper, + prediction_lm[, "upr"], + ignore_attr = TRUE + ) }) test_that('glm intervals', { - stats_glm <- glm(compounds ~ input_fields + iterations + num_pending, - data = hpc) + stats_glm <- glm( + compounds ~ input_fields + iterations + num_pending, + data = hpc + ) pred_glm <- predict(stats_glm, newdata = hpc[1:5, ], se.fit = TRUE) t_val <- qt(0.035, df = stats_glm$df.residual, lower.tail = FALSE) lower_glm <- pred_glm$fit - t_val * pred_glm$se.fit @@ -274,21 +291,17 @@ test_that('glm intervals', { upper_glm <- stats_glm$family$linkinv(upper_glm) res_xy <- fit_xy( - linear_reg() |> set_engine("glm"), + linear_reg() |> set_engine("glm"), x = hpc[, num_pred], y = hpc$compounds, control = ctrl ) confidence_parsnip <- - predict(res_xy, - new_data = hpc[1:5,], - type = "conf_int", - level = 0.93) + predict(res_xy, new_data = hpc[1:5, ], type = "conf_int", level = 0.93) expect_equal(confidence_parsnip$.pred_lower, lower_glm) expect_equal(confidence_parsnip$.pred_upper, upper_glm) - }) @@ -299,9 +312,7 @@ test_that('newdata error trapping', { y = hpc$input_fields, control = ctrl ) - expect_snapshot(error = TRUE, - predict(res_xy, newdata = hpc[1:3, num_pred]) - ) + expect_snapshot(error = TRUE, predict(res_xy, newdata = hpc[1:3, num_pred])) }) test_that('show engine', { @@ -318,17 +329,17 @@ test_that('lm can handle rankdeficient predictions', { ) data <- data.frame( - y = c(1,2,3,4), - x1 = c(1,1,2,3), - x2 = c(3,4,5,2), - x3 = c(4,2,6,0), - x4 = c(2,1,3,0) + y = c(1, 2, 3, 4), + x1 = c(1, 1, 2, 3), + x2 = c(3, 4, 5, 2), + x3 = c(4, 2, 6, 0), + x4 = c(2, 1, 3, 0) ) data2 <- data.frame( - x1 = c(3,2,1,3), - x2 = c(3,2,1,4), - x3 = c(3,4,5,1), - x4 = c(0,0,2,3) + x1 = c(3, 2, 1, 3), + x2 = c(3, 2, 1, 4), + x3 = c(3, 4, 5, 1), + x4 = c(0, 0, 2, 3) ) expect_snapshot( @@ -395,7 +406,6 @@ test_that("prevent using a Poisson family", { # ------------------------------------------------------------------------------ 
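[Illustrative aside, not part of the diff] The hunks above exercise parsnip's tidy interval predictions for the "lm" and "glm" engines against values computed directly with stats::predict.lm(). A minimal sketch of that pattern, using only base R data (mtcars) and the parsnip calls these tests rely on; the model formula here is purely for illustration:

library(parsnip)

# Plain lm fit through the parsnip interface.
lm_fit <- linear_reg() |>
  set_engine("lm") |>
  fit(mpg ~ cyl + disp + hp, data = mtcars)

# Tidy interval columns (.pred_lower / .pred_upper) at the non-default
# level = 0.93 used in the tests above.
predict(lm_fit, new_data = mtcars[1:5, ], type = "conf_int", level = 0.93)
predict(lm_fit, new_data = mtcars[1:5, ], type = "pred_int", level = 0.93)

# The same quantities computed directly with the engine, for comparison.
stats_lm <- lm(mpg ~ cyl + disp + hp, data = mtcars)
predict(stats_lm, newdata = mtcars[1:5, ], interval = "confidence", level = 0.93)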
test_that("tunables", { - expect_snapshot( linear_reg() |> tunable() @@ -423,6 +433,4 @@ test_that("tunables", { set_engine("keras") |> tunable() ) - }) - diff --git a/tests/testthat/test-linear_reg_keras.R b/tests/testthat/test-linear_reg_keras.R index 071d9725e..d3123ed9d 100644 --- a/tests/testthat/test-linear_reg_keras.R +++ b/tests/testthat/test-linear_reg_keras.R @@ -29,7 +29,7 @@ test_that('model fitting', { fit_xy( basic_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$compounds ) ) @@ -41,7 +41,7 @@ test_that('model fitting', { fit_xy( basic_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$compounds ) ) @@ -65,7 +65,7 @@ test_that('model fitting', { fit_xy( ridge_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$compounds ) ) @@ -78,7 +78,6 @@ test_that('model fitting', { control = ctrl ) ) - }) @@ -94,18 +93,18 @@ test_that('regression prediction', { fit_xy( basic_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$compounds ) keras_pred <- - predict(extract_fit_engine(lm_fit), as.matrix(hpc[1:3,2:4])) + predict(extract_fit_engine(lm_fit), as.matrix(hpc[1:3, 2:4])) colnames(keras_pred) <- ".pred" keras_pred <- keras_pred |> as_tibble() - parsnip_pred <- predict(lm_fit, hpc[1:3,2:4]) + parsnip_pred <- predict(lm_fit, hpc[1:3, 2:4]) expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) set.seed(257) @@ -113,15 +112,14 @@ test_that('regression prediction', { fit_xy( ridge_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$compounds ) - keras_pred <- predict(extract_fit_engine(rr_fit), as.matrix(hpc[1:3,2:4])) + keras_pred <- predict(extract_fit_engine(rr_fit), as.matrix(hpc[1:3, 2:4])) colnames(keras_pred) <- ".pred" keras_pred <- tibble::as_tibble(keras_pred) - parsnip_pred <- predict(rr_fit, hpc[1:3,2:4]) + parsnip_pred <- predict(rr_fit, hpc[1:3, 2:4]) expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - }) diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 171051c30..8fb8e7832 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -35,7 +35,7 @@ test_that('linear quantile regression via quantreg - single quantile', { ### - one_quant_one_row <- predict(one_quant, new_data = sac_test[1,]) + one_quant_one_row <- predict(one_quant, new_data = sac_test[1, ]) expect_true(nrow(one_quant_one_row) == 1L) expect_named(one_quant_one_row, ".pred_quantile") expect_true(is.list(one_quant_one_row[[1]])) @@ -49,8 +49,11 @@ test_that('linear quantile regression via quantreg - single quantile', { one_quant_one_row_df <- as_tibble(one_quant_one_row$.pred_quantile) expect_s3_class(one_quant_one_row_df, c("tbl_df", "tbl", "data.frame")) - expect_named(one_quant_one_row_df, c(".pred_quantile", ".quantile_levels", ".row")) - expect_true(nrow(one_quant_one_row_df) == nrow(sac_test[1,]) * 1) + expect_named( + one_quant_one_row_df, + c(".pred_quantile", ".quantile_levels", ".row") + ) + expect_true(nrow(one_quant_one_row_df) == nrow(sac_test[1, ]) * 1) }) test_that('linear quantile regression via quantreg - multiple quantiles', { @@ -61,7 +64,7 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { ten_quant <- linear_reg() |> set_engine("quantreg") |> - set_mode("quantile regression", quantile_levels = (0:9)/9) |> + set_mode("quantile regression", quantile_levels = (0:9) / 9) |> fit(price ~ ., data = sac_train) expect_s3_class(ten_quant, c("_rq", "model_fit")) @@ -78,7 +81,10 
@@ test_that('linear quantile regression via quantreg - multiple quantiles', { ) expect_identical(class(ten_quant_pred$.pred_quantile[[1]]), "numeric") expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L) - expect_identical(attr(ten_quant_pred$.pred_quantile, "quantile_levels"), (0:9)/9) + expect_identical( + attr(ten_quant_pred$.pred_quantile, "quantile_levels"), + (0:9) / 9 + ) ten_quant_df <- as_tibble(ten_quant_pred$.pred_quantile) expect_s3_class(ten_quant_df, c("tbl_df", "tbl", "data.frame")) @@ -86,13 +92,17 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) expect_snapshot( - ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:9)/9), + ten_quant_pred <- predict( + ten_quant, + new_data = sac_test, + quantile_levels = (0:9) / 9 + ), error = TRUE ) ### - ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,]) + ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1, ]) expect_true(nrow(ten_quant_one_row) == 1L) expect_named(ten_quant_one_row, ".pred_quantile") expect_true(is.list(ten_quant_one_row[[1]])) @@ -104,14 +114,14 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(length(ten_quant_one_row$.pred_quantile[[1]]) == 10L) expect_identical( attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), - (0:9)/9 + (0:9) / 9 ) ten_quant_one_row_df <- as_tibble(ten_quant_one_row$.pred_quantile) expect_s3_class(ten_quant_one_row_df, c("tbl_df", "tbl", "data.frame")) - expect_named(ten_quant_one_row_df, c(".pred_quantile", ".quantile_levels", ".row")) - expect_true(nrow(ten_quant_one_row_df) == nrow(sac_test[1,]) * 10) + expect_named( + ten_quant_one_row_df, + c(".pred_quantile", ".quantile_levels", ".row") + ) + expect_true(nrow(ten_quant_one_row_df) == nrow(sac_test[1, ]) * 10) }) - - - diff --git a/tests/testthat/test-logistic_reg.R b/tests/testthat/test-logistic_reg.R index 051dd7cd0..a30887171 100644 --- a/tests/testthat/test-logistic_reg.R +++ b/tests/testthat/test-logistic_reg.R @@ -14,7 +14,10 @@ test_that('updating', { test_that('bad input', { expect_snapshot(error = TRUE, logistic_reg(mode = "regression")) - expect_snapshot(error = TRUE, translate(logistic_reg(mixture = 0.5) |> set_engine(engine = "LiblineaR"))) + expect_snapshot( + error = TRUE, + translate(logistic_reg(mixture = 0.5) |> set_engine(engine = "LiblineaR")) + ) expect_snapshot( res <- @@ -34,8 +37,6 @@ lc_basic <- logistic_reg() |> set_engine("glm") ll_basic <- logistic_reg() |> set_engine("LiblineaR") test_that('glm execution', { - - # passes interactively but not on R CMD check # expect_no_condition( # res <- fit( @@ -95,12 +96,18 @@ test_that('glm prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(classes_xy), newdata = lending_club[1:7, num_pred], type = "response") + xy_pred <- predict( + extract_fit_engine(classes_xy), + newdata = lending_club[1:7, num_pred], + type = "response" + ) xy_pred <- ifelse(xy_pred >= 0.5, "good", "bad") xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, num_pred], type = "class")$.pred_class) - + expect_equal( + xy_pred, + predict(classes_xy, lending_club[1:7, num_pred], type = "class")$.pred_class + ) }) test_that('glm probabilities', { @@ -111,22 +118,28 @@ test_that('glm probabilities', { control = ctrl ) - xy_pred <- unname(predict(extract_fit_engine(classes_xy), - newdata = 
lending_club[1:7, num_pred], - type = "response")) + xy_pred <- unname(predict( + extract_fit_engine(classes_xy), + newdata = lending_club[1:7, num_pred], + type = "response" + )) xy_pred <- tibble(.pred_bad = 1 - xy_pred, .pred_good = xy_pred) - expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, num_pred], type = "prob")) + expect_equal( + xy_pred, + predict(classes_xy, lending_club[1:7, num_pred], type = "prob") + ) one_row <- predict(classes_xy, lending_club[1, num_pred], type = "prob") - expect_equal(xy_pred[1,], one_row) - + expect_equal(xy_pred[1, ], one_row) }) - test_that('glm intervals', { - stats_glm <- glm(Class ~ log(funded_amnt) + int_rate, data = lending_club, - family = binomial) + stats_glm <- glm( + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + family = binomial + ) pred_glm <- predict(stats_glm, newdata = lending_club[1:5, ], se.fit = TRUE) t_val <- qt(0.035, df = stats_glm$df.residual, lower.tail = FALSE) lower_glm <- pred_glm$fit - t_val * pred_glm$se.fit @@ -143,22 +156,22 @@ test_that('glm intervals', { ) confidence_parsnip <- - predict(res, - new_data = lending_club[1:5,], - type = "conf_int", - level = 0.93, - std_error = TRUE) + predict( + res, + new_data = lending_club[1:5, ], + type = "conf_int", + level = 0.93, + std_error = TRUE + ) expect_equal(confidence_parsnip$.pred_lower_good, lower_glm) expect_equal(confidence_parsnip$.pred_upper_good, upper_glm) expect_equal(confidence_parsnip$.pred_lower_bad, 1 - upper_glm) expect_equal(confidence_parsnip$.pred_upper_bad, 1 - lower_glm) expect_equal(confidence_parsnip$.std_error, pred_glm$se.fit) - }) test_that('liblinear execution', { - skip_if_not_installed("LiblineaR") expect_no_condition( @@ -206,12 +219,9 @@ test_that('liblinear execution', { y = lending_club$total_bal_il ) ) - - }) test_that('liblinear prediction', { - skip_if_not_installed("LiblineaR") classes_xy <- fit_xy( @@ -221,14 +231,18 @@ test_that('liblinear prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(classes_xy), newx = lending_club[1:7, num_pred]) + xy_pred <- predict( + extract_fit_engine(classes_xy), + newx = lending_club[1:7, num_pred] + ) xy_pred <- xy_pred$predictions - expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, num_pred], type = "class")$.pred_class) - + expect_equal( + xy_pred, + predict(classes_xy, lending_club[1:7, num_pred], type = "class")$.pred_class + ) }) test_that('liblinear probabilities', { - skip_if_not_installed("LiblineaR") classes_xy <- fit_xy( @@ -238,17 +252,20 @@ test_that('liblinear probabilities', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(classes_xy), - newx = lending_club[1:7, num_pred], - proba = TRUE) + xy_pred <- predict( + extract_fit_engine(classes_xy), + newx = lending_club[1:7, num_pred], + proba = TRUE + ) xy_pred <- as_tibble(xy_pred$probabilities) - xy_pred <- tibble(.pred_good = xy_pred$good, - .pred_bad = xy_pred$bad) - expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, num_pred], type = "prob")) + xy_pred <- tibble(.pred_good = xy_pred$good, .pred_bad = xy_pred$bad) + expect_equal( + xy_pred, + predict(classes_xy, lending_club[1:7, num_pred], type = "prob") + ) one_row <- predict(classes_xy, lending_club[1, num_pred], type = "prob") - expect_equal(xy_pred[1,], one_row) - + expect_equal(xy_pred[1, ], one_row) }) test_that("check_args() works", { @@ -293,7 +310,6 @@ test_that("check_args() works", { # ------------------------------------------------------------------------------ test_that("tunables", { - expect_snapshot( 
logistic_reg() |> tunable() @@ -315,5 +331,4 @@ test_that("tunables", { set_engine("keras") |> tunable() ) - }) diff --git a/tests/testthat/test-logistic_reg_keras.R b/tests/testthat/test-logistic_reg_keras.R index ef697b347..4b2479601 100644 --- a/tests/testthat/test-logistic_reg_keras.R +++ b/tests/testthat/test-logistic_reg_keras.R @@ -27,7 +27,7 @@ test_that('model fitting', { dplyr::sample_n(500) |> dplyr::ungroup() |> dplyr::select(Class, funded_amnt, int_rate) - dat <- dat[order(runif(nrow(dat))),] + dat <- dat[order(runif(nrow(dat))), ] tr_dat <- dat[1:995, ] te_dat <- dat[996:1000, ] @@ -37,11 +37,11 @@ test_that('model fitting', { expect_no_condition( fit1 <- fit_xy( - basic_mod, - control = ctrl, - x = tr_dat[, -1], - y = tr_dat$Class - ) + basic_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ) ) set_tf_seed(257) @@ -88,7 +88,6 @@ test_that('model fitting', { control = ctrl ) ) - }) @@ -104,7 +103,7 @@ test_that('classification prediction', { dplyr::sample_n(500) |> dplyr::ungroup() |> dplyr::select(Class, funded_amnt, int_rate) - dat <- dat[order(runif(nrow(dat))),] + dat <- dat[order(runif(nrow(dat))), ] tr_dat <- dat[1:995, ] te_dat <- dat[996:1000, ] @@ -124,7 +123,9 @@ test_that('classification prediction', { predict(lr_fit$fit, as.matrix(te_dat[, -1])) keras_pred <- tibble::tibble(.pred_class = apply(keras_raw, 1, which.max)) |> - dplyr::mutate(.pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl)) + dplyr::mutate( + .pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl) + ) parsnip_pred <- predict(lr_fit, te_dat[, -1]) expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) @@ -145,7 +146,6 @@ test_that('classification prediction', { mutate(.pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -1]) expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - }) @@ -161,7 +161,7 @@ test_that('classification probabilities', { dplyr::sample_n(500) |> dplyr::ungroup() |> dplyr::select(Class, funded_amnt, int_rate) - dat <- dat[order(runif(nrow(dat))),] + dat <- dat[order(runif(nrow(dat))), ] tr_dat <- dat[1:995, ] te_dat <- dat[996:1000, ] @@ -201,5 +201,4 @@ test_that('classification probabilities', { parsnip_pred <- predict(plrfit, te_dat[, -1], type = "prob") expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - }) diff --git a/tests/testthat/test-mars.R b/tests/testthat/test-mars.R index e4f925a42..68a38bf09 100644 --- a/tests/testthat/test-mars.R +++ b/tests/testthat/test-mars.R @@ -22,7 +22,10 @@ test_that('updating', { }) test_that('bad input', { - expect_snapshot(error = TRUE, translate(mars(mode = "regression") |> set_engine())) + expect_snapshot( + error = TRUE, + translate(mars(mode = "regression") |> set_engine()) + ) expect_snapshot(error = TRUE, translate(mars() |> set_engine("wat?"))) }) @@ -78,26 +81,45 @@ test_that('mars execution', { ) ) parsnip:::load_libs(res, attach = TRUE) - }) test_that('mars prediction', { skip_if_not_installed("earth") - uni_pred <- c(30.1466666666667, 30.1466666666667, 30.1466666666667, - 30.1466666666667, 30.1466666666667) - inl_pred <- c(538.268789262046, 141.024903718634, 141.024903718634, - 141.024903718634, 141.024903718634) + uni_pred <- c( + 30.1466666666667, + 30.1466666666667, + 30.1466666666667, + 30.1466666666667, + 30.1466666666667 + ) + inl_pred <- c( + 538.268789262046, + 141.024903718634, + 141.024903718634, + 141.024903718634, + 141.024903718634 + ) mv_pred <- structure( - 
list(compounds = - c(371.334864384913, 129.475162245595, 256.094366313268, - 129.475162245595, 129.475162245595), - input_fields = - c(430.476046435458, 158.833790342308, 218.07635084308, - 158.833790342308, 158.833790342308) + list( + compounds = c( + 371.334864384913, + 129.475162245595, + 256.094366313268, + 129.475162245595, + 129.475162245595 + ), + input_fields = c( + 430.476046435458, + 158.833790342308, + 218.07635084308, + 158.833790342308, + 158.833790342308 + ) ), - class = "data.frame", row.names = c(NA, -5L) + class = "data.frame", + row.names = c(NA, -5L) ) res_xy <- fit_xy( @@ -125,7 +147,7 @@ test_that('mars prediction', { ) expect_equal( setNames(mv_pred, paste0(".pred_", names(mv_pred))) |> as.data.frame(), - predict(res_mv, hpc[1:5,]) |> as.data.frame() + predict(res_mv, hpc[1:5, ]) |> as.data.frame() ) }) @@ -145,38 +167,52 @@ test_that('submodel prediction', { parsnip:::load_libs(reg_fit$spec, quiet = TRUE, attach = TRUE) tmp_reg <- extract_fit_engine(reg_fit) tmp_reg$call[["pmethod"]] <- eval_tidy(tmp_reg$call[["pmethod"]]) - tmp_reg$call[["keepxy"]] <- eval_tidy(tmp_reg$call[["keepxy"]]) - tmp_reg$call[["nprune"]] <- eval_tidy(tmp_reg$call[["nprune"]]) - + tmp_reg$call[["keepxy"]] <- eval_tidy(tmp_reg$call[["keepxy"]]) + tmp_reg$call[["nprune"]] <- eval_tidy(tmp_reg$call[["nprune"]]) pruned_reg <- update(tmp_reg, nprune = 5) - pruned_reg_pred <- predict(pruned_reg, mtcars[1:4, -1])[,1] + pruned_reg_pred <- predict(pruned_reg, mtcars[1:4, -1])[, 1] mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], num_terms = 5) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred"]], pruned_reg_pred) full_churn <- wa_churn[complete.cases(wa_churn), ] - vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") + vars <- c( + "female", + "tenure", + "total_charges", + "phone_service", + "monthly_charges" + ) class_fit <- - mars(mode = "classification", prune_method = "none") |> + mars(mode = "classification", prune_method = "none") |> set_engine("earth", keepxy = TRUE) |> - fit(churn ~ ., - data = full_churn[-(1:4), c("churn", vars)]) + fit(churn ~ ., data = full_churn[-(1:4), c("churn", vars)]) cls_fit <- extract_fit_engine(class_fit) cls_fit$call[["pmethod"]] <- eval_tidy(cls_fit$call[["pmethod"]]) - cls_fit$call[["keepxy"]] <- eval_tidy(cls_fit$call[["keepxy"]]) - cls_fit$call[["glm"]] <- eval_tidy(cls_fit$call[["glm"]]) + cls_fit$call[["keepxy"]] <- eval_tidy(cls_fit$call[["keepxy"]]) + cls_fit$call[["glm"]] <- eval_tidy(cls_fit$call[["glm"]]) pruned_cls <- update(cls_fit, nprune = 5) - pruned_cls_pred <- predict(pruned_cls, full_churn[1:4, vars], type = "response")[,1] - - mp_res <- multi_predict(class_fit, new_data = full_churn[1:4, vars], num_terms = 5, type = "prob") + pruned_cls_pred <- predict( + pruned_cls, + full_churn[1:4, vars], + type = "response" + )[, 1] + + mp_res <- multi_predict( + class_fit, + new_data = full_churn[1:4, vars], + num_terms = 5, + type = "prob" + ) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_No"]], pruned_cls_pred) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, multi_predict(reg_fit, newdata = mtcars[1:4, -1], num_terms = 5) ) }) @@ -189,16 +225,25 @@ test_that('classification', { expect_no_condition( glm_mars <- - mars(mode = "classification") |> + mars(mode = "classification") |> set_engine("earth") |> - fit(Class ~ ., data = modeldata::lending_club[-(1:5),]) + fit(Class ~ ., data = modeldata::lending_club[-(1:5), ]) ) 
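[Illustrative aside, not part of the diff] The submodel-prediction test above relies on multi_predict() generating earth predictions for several values of num_terms without refitting. A minimal sketch of that usage, assuming the earth package is installed; the data and num_terms values are illustrative only:

library(parsnip)

# MARS regression via the earth engine.
mars_fit <- mars(mode = "regression", prune_method = "none") |>
  set_engine("earth", keepxy = TRUE) |>
  fit(mpg ~ ., data = mtcars)

# Ordinary prediction at the fitted model size.
predict(mars_fit, new_data = mtcars[1:4, -1])

# Submodel predictions: one row per observation, with a `.pred` list column
# holding one row per requested value of `num_terms`.
multi_predict(mars_fit, new_data = mtcars[1:4, -1], num_terms = c(3, 5))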
expect_true(!is.null(extract_fit_engine(glm_mars)$glm.list)) - parsnip_pred <- predict(glm_mars, new_data = lending_club[1:5, -ncol(lending_club)], type = "prob") + parsnip_pred <- predict( + glm_mars, + new_data = lending_club[1:5, -ncol(lending_club)], + type = "prob" + ) earth_pred <- - c(0.95631355972526, 0.971917781277731, 0.894245392500336, 0.962667553751077, - 0.985827594261896) + c( + 0.95631355972526, + 0.971917781277731, + 0.894245392500336, + 0.962667553751077, + 0.985827594261896 + ) expect_equal(parsnip_pred$.pred_good, earth_pred) }) diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index ede7a6fba..3b9cf5d8b 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -25,20 +25,16 @@ test_that('parsnip objects', { error = TRUE, multi_predict(extract_fit_engine(mars_fit), mtcars) ) - }) test_that('other objects', { - expect_false(has_multi_predict(NULL)) expect_false(has_multi_predict(NA)) - }) # ------------------------------------------------------------------------------ test_that('S3 method dispatch/registration', { - expect_no_condition( res <- null_model() |> @@ -58,7 +54,6 @@ test_that('S3 method dispatch/registration', { tidy() ) expect_true(tibble::is_tibble(res)) - }) # ------------------------------------------------------------------------------ @@ -102,7 +97,6 @@ test_that('correct mtry', { expect_equal(max_mtry_formula(2, f_2, ames), 2) expect_equal(max_mtry_formula(200, f_3, data = mtcars), ncol(mtcars) - 2) - }) # ---------------------------------------------------------------------------- @@ -171,17 +165,15 @@ test_that('arguments can be passed to model spec inside function', { test_that('set_engine works as a generic', { - expect_snapshot(error = TRUE, - set_engine(mtcars, "rpart") - ) - + expect_snapshot(error = TRUE, set_engine(mtcars, "rpart")) }) test_that('check_for_newdata points out correct context', { - fn <- function(...) {check_for_newdata(...); invisible()} - expect_snapshot(error = TRUE, - fn(newdata = "boop!") - ) + fn <- function(...) { + check_for_newdata(...) 
+ invisible() + } + expect_snapshot(error = TRUE, fn(newdata = "boop!")) }) test_that('check_outcome works as expected', { @@ -207,7 +199,7 @@ test_that('check_outcome works as expected', { expect_snapshot( error = TRUE, - fit(reg_spec, ~ mpg, mtcars) + fit(reg_spec, ~mpg, mtcars) ) expect_snapshot( @@ -237,7 +229,7 @@ test_that('check_outcome works as expected', { expect_snapshot( error = TRUE, - fit(class_spec, ~ mpg, mtcars) + fit(class_spec, ~mpg, mtcars) ) # Fake specification to avoid having to load {censored} @@ -264,26 +256,28 @@ test_that('obtaining prediction columns', { lr_fit <- logistic_reg() |> fit(Class ~ ., data = two_class_dat) expect_equal( .get_prediction_column_names(lr_fit), - list(estimate = ".pred_class", - probabilities = c(".pred_Class1", ".pred_Class2")) + list( + estimate = ".pred_class", + probabilities = c(".pred_Class1", ".pred_Class2") + ) ) expect_equal( .get_prediction_column_names(lr_fit, syms = TRUE), - list(estimate = list(quote(.pred_class)), - probabilities = list(quote(.pred_Class1), quote(.pred_Class2))) + list( + estimate = list(quote(.pred_class)), + probabilities = list(quote(.pred_Class1), quote(.pred_Class2)) + ) ) ### regression ols_fit <- linear_reg() |> fit(mpg ~ ., data = mtcars) expect_equal( .get_prediction_column_names(ols_fit), - list(estimate = ".pred", - probabilities = character(0)) + list(estimate = ".pred", probabilities = character(0)) ) expect_equal( .get_prediction_column_names(ols_fit, syms = TRUE), - list(estimate = list(quote(.pred)), - probabilities = list()) + list(estimate = list(quote(.pred)), probabilities = list()) ) ### censored regression @@ -301,7 +295,6 @@ test_that('obtaining prediction columns', { .get_prediction_column_names(unk_fit), error = TRUE ) - }) @@ -343,4 +336,3 @@ test_that('register local models', { expect_snapshot(my_model() |> translate("my_engine")) }) - diff --git a/tests/testthat/test-mlp.R b/tests/testthat/test-mlp.R index 3aaf37a4e..c2f769b6c 100644 --- a/tests/testthat/test-mlp.R +++ b/tests/testthat/test-mlp.R @@ -12,7 +12,10 @@ test_that('updating', { test_that('bad input', { expect_snapshot(error = TRUE, mlp(mode = "time series")) - expect_snapshot(error = TRUE, translate(mlp(mode = "classification") |> set_engine("wat?"))) + expect_snapshot( + error = TRUE, + translate(mlp(mode = "classification") |> set_engine("wat?")) + ) expect_warning( translate(mlp(mode = "regression") |> set_engine("nnet", formula = y ~ x)), class = "parsnip_protected_arg_warning" @@ -38,8 +41,8 @@ test_that("more activations for brulee", { set.seed(122) in_train <- sample(1:nrow(ames), 2000) - ames_train <- ames[ in_train,] - ames_test <- ames[-in_train,] + ames_train <- ames[in_train, ] + ames_test <- ames[-in_train, ] set.seed(1) fit <- @@ -47,9 +50,12 @@ test_that("more activations for brulee", { mlp(penalty = 0.10, activation = "softplus") |> set_mode("regression") |> set_engine("brulee") |> - fit_xy(x = as.matrix(ames_train[, c("Longitude", "Latitude")]), - y = ames_train$Sale_Price), - silent = TRUE) + fit_xy( + x = as.matrix(ames_train[, c("Longitude", "Latitude")]), + y = ames_train$Sale_Price + ), + silent = TRUE + ) expect_true(inherits(fit$fit, "brulee_mlp")) }) @@ -88,7 +94,6 @@ test_that("check_args() works", { # ------------------------------------------------------------------------------ test_that("tunables", { - expect_snapshot( mlp() |> set_engine("brulee") |> @@ -111,5 +116,4 @@ test_that("tunables", { set_engine("keras") |> tunable() ) - }) diff --git a/tests/testthat/test-mlp_keras.R 
b/tests/testthat/test-mlp_keras.R index 89dd05f25..b0ca441b0 100644 --- a/tests/testthat/test-mlp_keras.R +++ b/tests/testthat/test-mlp_keras.R @@ -27,7 +27,6 @@ test_that('keras execution, classification', { ) ) - expect_false(has_multi_predict(res)) expect_equal(multi_predict_args(res), NA_character_) @@ -69,7 +68,10 @@ test_that('keras classification prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), x = as.matrix(hpc[1:8, num_pred])) + xy_pred <- predict( + extract_fit_engine(xy_fit), + x = as.matrix(hpc[1:8, num_pred]) + ) if (tensorflow::tf_version() <= package_version("2.0.0")) { # -1 to assign with keras' zero indexing xy_pred <- apply(xy_pred, 1, which.max) - 1 @@ -78,7 +80,12 @@ test_that('keras classification prediction', { } xy_pred <- factor(levels(hpc$class)[xy_pred + 1], levels = levels(hpc$class)) - expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")[[".pred_class"]]) + expect_equal( + xy_pred, + predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")[[ + ".pred_class" + ]] + ) keras::backend()$clear_session() @@ -89,8 +96,10 @@ test_that('keras classification prediction', { control = ctrl ) - - form_pred <- predict(extract_fit_engine(form_fit), x = as.matrix(hpc[1:8, num_pred])) + form_pred <- predict( + extract_fit_engine(form_fit), + x = as.matrix(hpc[1:8, num_pred]) + ) if (tensorflow::tf_version() <= package_version("2.0.0")) { # -1 to assign with keras' zero indexing form_pred <- apply(form_pred, 1, which.max) - 1 @@ -98,8 +107,16 @@ test_that('keras classification prediction', { form_pred <- form_pred |> keras::k_argmax() |> as.integer() } - form_pred <- factor(levels(hpc$class)[form_pred + 1], levels = levels(hpc$class)) - expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred], type = "class")[[".pred_class"]]) + form_pred <- factor( + levels(hpc$class)[form_pred + 1], + levels = levels(hpc$class) + ) + expect_equal( + form_pred, + predict(form_fit, new_data = hpc[1:8, num_pred], type = "class")[[ + ".pred_class" + ]] + ) keras::backend()$clear_session() }) @@ -117,10 +134,16 @@ test_that('keras classification probabilities', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), x = as.matrix(hpc[1:8, num_pred])) + xy_pred <- predict( + extract_fit_engine(xy_fit), + x = as.matrix(hpc[1:8, num_pred]) + ) colnames(xy_pred) <- paste0(".pred_", levels(hpc$class)) xy_pred <- as_tibble(xy_pred) - expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "prob")) + expect_equal( + xy_pred, + predict(xy_fit, new_data = hpc[1:8, num_pred], type = "prob") + ) keras::backend()$clear_session() @@ -131,10 +154,16 @@ test_that('keras classification probabilities', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), x = as.matrix(hpc[1:8, num_pred])) + form_pred <- predict( + extract_fit_engine(form_fit), + x = as.matrix(hpc[1:8, num_pred]) + ) colnames(form_pred) <- paste0(".pred_", levels(hpc$class)) form_pred <- as_tibble(form_pred) - expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred], type = "prob")) + expect_equal( + form_pred, + predict(form_fit, new_data = hpc[1:8, num_pred], type = "prob") + ) keras::backend()$clear_session() }) @@ -194,8 +223,14 @@ test_that('keras regression prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), x = as.matrix(mtcars[1:8, c("cyl", "disp")]))[,1] - expect_equal(xy_pred, predict(xy_fit, new_data = mtcars[1:8, c("cyl", "disp")])[[".pred"]]) + xy_pred <- 
predict( + extract_fit_engine(xy_fit), + x = as.matrix(mtcars[1:8, c("cyl", "disp")]) + )[, 1] + expect_equal( + xy_pred, + predict(xy_fit, new_data = mtcars[1:8, c("cyl", "disp")])[[".pred"]] + ) keras::backend()$clear_session() @@ -206,8 +241,14 @@ test_that('keras regression prediction', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), x = as.matrix(mtcars[1:8, c("cyl", "disp")]))[,1] - expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, c("cyl", "disp")])[[".pred"]]) + form_pred <- predict( + extract_fit_engine(form_fit), + x = as.matrix(mtcars[1:8, c("cyl", "disp")]) + )[, 1] + expect_equal( + form_pred, + predict(form_fit, new_data = mtcars[1:8, c("cyl", "disp")])[[".pred"]] + ) keras::backend()$clear_session() }) @@ -220,13 +261,16 @@ test_that('multivariate nnet formula', { skip_if(!is_tf_ok()) nnet_form <- - mlp(mode = "regression", hidden_units = 3, penalty = 0.01) |> + mlp(mode = "regression", hidden_units = 3, penalty = 0.01) |> set_engine("keras", verbose = 0) |> parsnip::fit( cbind(V1, V2, V3) ~ ., - data = nn_dat[-(1:5),] + data = nn_dat[-(1:5), ] ) - expect_equal(length(unlist(keras::get_weights(extract_fit_engine(nnet_form)))), 24) + expect_equal( + length(unlist(keras::get_weights(extract_fit_engine(nnet_form)))), + 24 + ) nnet_form_pred <- predict(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(names(nnet_form_pred), paste0(".pred_", c("V1", "V2", "V3"))) @@ -234,17 +278,19 @@ test_that('multivariate nnet formula', { keras::backend()$clear_session() nnet_xy <- - mlp(mode = "regression", hidden_units = 3, penalty = 0.01) |> + mlp(mode = "regression", hidden_units = 3, penalty = 0.01) |> set_engine("keras", verbose = 0) |> parsnip::fit_xy( x = nn_dat[-(1:5), -(1:3)], - y = nn_dat[-(1:5), 1:3 ] + y = nn_dat[-(1:5), 1:3] ) - expect_equal(length(unlist(keras::get_weights(extract_fit_engine(nnet_xy)))), 24) + expect_equal( + length(unlist(keras::get_weights(extract_fit_engine(nnet_xy)))), + 24 + ) nnet_form_xy <- predict(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(names(nnet_form_pred), paste0(".pred_", c("V1", "V2", "V3"))) - keras::backend()$clear_session() }) @@ -262,12 +308,17 @@ test_that('all keras activation functions', { test_act <- function(fn) { set.seed(1) try( - mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2, - activation = !!fn) |> + mlp( + mode = "classification", + hidden_units = 2, + penalty = 0.01, + epochs = 2, + activation = !!fn + ) |> set_engine("keras", verbose = 0) |> parsnip::fit(Class ~ A + B, data = modeldata::two_class_dat), - silent = TRUE) - + silent = TRUE + ) } test_act_sshhh <- purrr::quietly(test_act) @@ -280,8 +331,13 @@ test_that('all keras activation functions', { expect_snapshot( error = TRUE, - mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2, - activation = "invalid") |> + mlp( + mode = "classification", + hidden_units = 2, + penalty = 0.01, + epochs = 2, + activation = "invalid" + ) |> set_engine("keras", verbose = 0) |> parsnip::fit(Class ~ A + B, data = modeldata::two_class_dat) ) diff --git a/tests/testthat/test-mlp_nnet.R b/tests/testthat/test-mlp_nnet.R index 22fbcb359..51e53ece0 100644 --- a/tests/testthat/test-mlp_nnet.R +++ b/tests/testthat/test-mlp_nnet.R @@ -11,7 +11,6 @@ hpc_nnet <- # ------------------------------------------------------------------------------ test_that('nnet execution, classification', { - skip_if_not_installed("nnet") expect_no_condition( @@ -45,7 +44,6 @@ test_that('nnet execution, classification', { 
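[Illustrative aside, not part of the diff] The mlp tests in this file check class and probability predictions from a single-layer network against the underlying nnet fit. A minimal sketch of the parsnip side of that comparison, using base R's iris data for illustration:

library(parsnip)

# Single-hidden-layer network via nnet (one of R's recommended packages).
nnet_fit <- mlp(hidden_units = 3, penalty = 0.01, epochs = 100) |>
  set_engine("nnet") |>
  set_mode("classification") |>
  fit(Species ~ ., data = iris)

# Hard class and class-probability predictions, as exercised above.
predict(nnet_fit, new_data = iris[1:5, -5], type = "class")
predict(nnet_fit, new_data = iris[1:5, -5], type = "prob")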
test_that('nnet classification prediction', { - skip_if_not_installed("nnet") xy_fit <- fit_xy( @@ -55,9 +53,16 @@ test_that('nnet classification prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), newdata = hpc[1:8, num_pred], type = "class") + xy_pred <- predict( + extract_fit_engine(xy_fit), + newdata = hpc[1:8, num_pred], + type = "class" + ) xy_pred <- factor(xy_pred, levels = levels(hpc$class)) - expect_equal(xy_pred, predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")$.pred_class) + expect_equal( + xy_pred, + predict(xy_fit, new_data = hpc[1:8, num_pred], type = "class")$.pred_class + ) form_fit <- fit( hpc_nnet, @@ -66,9 +71,16 @@ test_that('nnet classification prediction', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), newdata = hpc[1:8, num_pred], type = "class") + form_pred <- predict( + extract_fit_engine(form_fit), + newdata = hpc[1:8, num_pred], + type = "class" + ) form_pred <- factor(form_pred, levels = levels(hpc$class)) - expect_equal(form_pred, predict(form_fit, new_data = hpc[1:8, num_pred])$.pred_class) + expect_equal( + form_pred, + predict(form_fit, new_data = hpc[1:8, num_pred])$.pred_class + ) }) @@ -89,9 +101,7 @@ bad_rf_reg <- # ------------------------------------------------------------------------------ - test_that('nnet execution, regression', { - skip_if_not_installed("nnet") expect_no_condition( @@ -114,9 +124,7 @@ test_that('nnet execution, regression', { }) - test_that('nnet regression prediction', { - skip_if_not_installed("nnet") xy_fit <- fit_xy( @@ -126,7 +134,7 @@ test_that('nnet regression prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), newdata = mtcars[1:8, -1])[,1] + xy_pred <- predict(extract_fit_engine(xy_fit), newdata = mtcars[1:8, -1])[, 1] xy_pred <- unname(xy_pred) expect_equal(xy_pred, predict(xy_fit, new_data = mtcars[1:8, -1])$.pred) @@ -137,7 +145,10 @@ test_that('nnet regression prediction', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), newdata = mtcars[1:8, -1])[,1] + form_pred <- predict( + extract_fit_engine(form_fit), + newdata = mtcars[1:8, -1] + )[, 1] form_pred <- unname(form_pred) expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred) }) @@ -147,7 +158,6 @@ test_that('nnet regression prediction', { nn_dat <- read.csv("nnet_test.txt") test_that('multivariate nnet formula', { - skip_if_not_installed("nnet") nnet_form <- @@ -155,11 +165,11 @@ test_that('multivariate nnet formula', { mode = "regression", hidden_units = 3, penalty = 0.01 - ) |> + ) |> set_engine("nnet") |> parsnip::fit( cbind(V1, V2, V3) ~ ., - data = nn_dat[-(1:5),] + data = nn_dat[-(1:5), ] ) expect_false(has_multi_predict(nnet_form)) @@ -178,12 +188,9 @@ test_that('multivariate nnet formula', { set_engine("nnet") |> parsnip::fit_xy( x = nn_dat[-(1:5), -(1:3)], - y = nn_dat[-(1:5), 1:3 ] + y = nn_dat[-(1:5), 1:3] ) expect_equal(length(extract_fit_engine(nnet_xy)$wts), 24) nnet_form_xy <- predict(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(names(nnet_form_xy), paste0(".pred_", c("V1", "V2", "V3"))) }) - - - diff --git a/tests/testthat/test-model_basics.R b/tests/testthat/test-model_basics.R index abd70c00b..6138e8477 100644 --- a/tests/testthat/test-model_basics.R +++ b/tests/testthat/test-model_basics.R @@ -97,6 +97,4 @@ test_that('basic object classes and print methods', { expect_snapshot(print(svm_rbf())) expect_true(inherits(svm_rbf(engine = 'kernlab'), 'svm_rbf')) expect_true(inherits(svm_rbf(engine = 
'liquidSVM'), 'svm_rbf')) - }) - diff --git a/tests/testthat/test-multinom_reg.R b/tests/testthat/test-multinom_reg.R index dbb1ca130..42cfb72ec 100644 --- a/tests/testthat/test-multinom_reg.R +++ b/tests/testthat/test-multinom_reg.R @@ -13,11 +13,15 @@ test_that('updating', { test_that('bad input', { expect_snapshot(error = TRUE, multinom_reg(mode = "regression")) - expect_snapshot(error = TRUE, translate(multinom_reg(penalty = 0.1) |> set_engine("wat?"))) + expect_snapshot( + error = TRUE, + translate(multinom_reg(penalty = 0.1) |> set_engine("wat?")) + ) expect_snapshot(error = TRUE, multinom_reg(penalty = 0.1) |> set_engine()) expect_warning( translate( - multinom_reg(penalty = 0.1) |> set_engine("glmnet", x = hpc[,1:3], y = hpc$class) + multinom_reg(penalty = 0.1) |> + set_engine("glmnet", x = hpc[, 1:3], y = hpc$class) ), class = "parsnip_protected_arg_warning" ) @@ -49,7 +53,6 @@ test_that('check_args() works', { # ------------------------------------------------------------------------------ test_that("tunables", { - expect_snapshot( multinom_reg() |> tunable() @@ -78,5 +81,4 @@ test_that("tunables", { set_engine("keras") |> tunable() ) - }) diff --git a/tests/testthat/test-multinom_reg_keras.R b/tests/testthat/test-multinom_reg_keras.R index 25c51043b..03b96f2ac 100644 --- a/tests/testthat/test-multinom_reg_keras.R +++ b/tests/testthat/test-multinom_reg_keras.R @@ -6,7 +6,7 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ set.seed(352) -dat <- hpc[order(runif(150)),] +dat <- hpc[order(runif(150)), ] tr_dat <- dat[1:140, ] te_dat <- dat[141:150, ] @@ -86,7 +86,6 @@ test_that('model fitting', { control = ctrl ) ) - }) @@ -110,7 +109,9 @@ test_that('classification prediction', { predict(extract_fit_engine(lr_fit), as.matrix(te_dat[, -5])) keras_pred <- tibble::tibble(.pred_class = apply(keras_raw, 1, which.max)) |> - dplyr::mutate(.pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl)) + dplyr::mutate( + .pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl) + ) parsnip_pred <- predict(lr_fit, te_dat[, -5]) expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) @@ -128,10 +129,11 @@ test_that('classification prediction', { predict(extract_fit_engine(plrfit), as.matrix(te_dat[, -5])) keras_pred <- tibble::tibble(.pred_class = apply(keras_raw, 1, which.max)) |> - dplyr::mutate(.pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl)) + dplyr::mutate( + .pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl) + ) parsnip_pred <- predict(plrfit, te_dat[, -5]) expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - }) @@ -176,7 +178,4 @@ test_that('classification probabilities', { setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -5], type = "prob") expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) - }) - - diff --git a/tests/testthat/test-multinom_reg_nnet.R b/tests/testthat/test-multinom_reg_nnet.R index 5d953d9f5..19e7966a2 100644 --- a/tests/testthat/test-multinom_reg_nnet.R +++ b/tests/testthat/test-multinom_reg_nnet.R @@ -5,7 +5,7 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ set.seed(352) -dat <- hpc[order(runif(150)),] +dat <- hpc[order(runif(150)), ] tr_dat <- dat[1:140, ] te_dat <- dat[141:150, ] @@ -55,7 +55,6 @@ test_that('model fitting', { control = ctrl ) ) - }) @@ -76,7 +75,6 @@ test_that('classification 
prediction', { parsnip_pred <- predict(lr_fit, te_dat[, -5]) expect_equal(nnet_pred, parsnip_pred$.pred_class) - }) @@ -93,13 +91,16 @@ test_that('classification probabilities', { ) nnet_pred <- - predict(extract_fit_engine(lr_fit), as.matrix(te_dat[, -5]), type = "prob") |> + predict( + extract_fit_engine(lr_fit), + as.matrix(te_dat[, -5]), + type = "prob" + ) |> as_tibble(.name_repair = "minimal") |> setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(lr_fit, te_dat[, -5], type = "prob") expect_equal(as.data.frame(nnet_pred), as.data.frame(parsnip_pred)) - }) test_that('prob prediction with 1 row', { @@ -116,7 +117,11 @@ test_that('prob prediction with 1 row', { ) nnet_pred <- - predict(extract_fit_engine(lr_fit), as.matrix(te_dat[1, -5]), type = "prob") |> + predict( + extract_fit_engine(lr_fit), + as.matrix(te_dat[1, -5]), + type = "prob" + ) |> as.matrix() |> t() |> tibble::as_tibble(.name_repair = "minimal") |> @@ -127,5 +132,3 @@ test_that('prob prediction with 1 row', { expect_equal(nnet_pred, parsnip_pred) expect_identical(nrow(parsnip_pred), 1L) }) - - diff --git a/tests/testthat/test-naive_Bayes.R b/tests/testthat/test-naive_Bayes.R index 85639dc94..d95d58a5e 100644 --- a/tests/testthat/test-naive_Bayes.R +++ b/tests/testthat/test-naive_Bayes.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/discrim expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-nearest_neighbor_kknn.R b/tests/testthat/test-nearest_neighbor_kknn.R index 377ddc597..1f1b99f23 100644 --- a/tests/testthat/test-nearest_neighbor_kknn.R +++ b/tests/testthat/test-nearest_neighbor_kknn.R @@ -5,15 +5,16 @@ hpc <- hpc_data[1:150, c(2:5, 8)] num_pred <- c("compounds", "iterations", "num_pending") hpc_bad_form <- as.formula(class ~ term) -hpc_basic <- nearest_neighbor(mode = "classification", - neighbors = 8, - weight_func = "triangular") |> +hpc_basic <- nearest_neighbor( + mode = "classification", + neighbors = 8, + weight_func = "triangular" +) |> set_engine("kknn") # ------------------------------------------------------------------------------ test_that('kknn execution', { - skip_if_not_installed("kknn") library(kknn) @@ -40,11 +41,9 @@ test_that('kknn execution', { control = ctrl ) ) - }) test_that('kknn prediction', { - skip_if_not_installed("kknn") library(kknn) @@ -61,7 +60,10 @@ test_that('kknn prediction', { newdata = hpc[1:5, num_pred] ) - expect_equal(tibble(.pred_class = uni_pred), predict(res_xy, hpc[1:5, num_pred])) + expect_equal( + tibble(.pred_class = uni_pred), + predict(res_xy, hpc[1:5, num_pred]) + ) # nominal res_xy_nom <- fit_xy( @@ -78,7 +80,11 @@ test_that('kknn prediction', { expect_equal( uni_pred_nom, - predict(res_xy_nom, hpc[1:5, c("input_fields", "iterations")], type = "class")$.pred_class + predict( + res_xy_nom, + hpc[1:5, c("input_fields", "iterations")], + type = "class" + )$.pred_class ) # continuous - formula interface @@ -91,15 +97,17 @@ test_that('kknn prediction', { form_pred <- predict( extract_fit_engine(res_form), - newdata = hpc[1:5,] + newdata = hpc[1:5, ] ) - expect_equal(form_pred, predict(res_form, hpc[1:5, c("compounds", "class")])$.pred) + expect_equal( + form_pred, + predict(res_form, hpc[1:5, c("compounds", "class")])$.pred + ) }) test_that('kknn multi-predict', { - skip_if_not_installed("kknn") library(kknn) @@ -115,8 +123,10 @@ test_that('kknn multi-predict', { ) pred_multi <- multi_predict(res_xy, hpc[hpc_te, num_pred], neighbors = k_vals) - expect_equal(pred_multi |> tidyr::unnest(cols = 
c(.pred)) |> nrow(), - length(hpc_te) * length(k_vals)) + expect_equal( + pred_multi |> tidyr::unnest(cols = c(.pred)) |> nrow(), + length(hpc_te) * length(k_vals) + ) expect_equal(pred_multi |> nrow(), length(hpc_te)) pred_uni <- predict(res_xy, hpc[hpc_te, num_pred]) @@ -129,11 +139,16 @@ test_that('kknn multi-predict', { dplyr::select(.pred_class) expect_equal(pred_uni, pred_uni_obs) - - prob_multi <- multi_predict(res_xy, hpc[hpc_te, num_pred], - neighbors = k_vals, type = "prob") - expect_equal(prob_multi |> tidyr::unnest(cols = c(.pred)) |> nrow(), - length(hpc_te) * length(k_vals)) + prob_multi <- multi_predict( + res_xy, + hpc[hpc_te, num_pred], + neighbors = k_vals, + type = "prob" + ) + expect_equal( + prob_multi |> tidyr::unnest(cols = c(.pred)) |> nrow(), + length(hpc_te) * length(k_vals) + ) expect_equal(prob_multi |> nrow(), length(hpc_te)) prob_uni <- predict(res_xy, hpc[hpc_te, num_pred], type = "prob") @@ -156,12 +171,15 @@ test_that('kknn multi-predict', { nearest_neighbor(mode = "regression", neighbors = 3) |> set_engine("kknn"), control = ctrl, - mpg ~ ., data = mtcars[-cars_te, ] + mpg ~ ., + data = mtcars[-cars_te, ] ) pred_multi <- multi_predict(res_xy, mtcars[cars_te, -1], neighbors = k_vals) - expect_equal(pred_multi |> tidyr::unnest(cols = c(.pred)) |> nrow(), - length(cars_te) * length(k_vals)) + expect_equal( + pred_multi |> tidyr::unnest(cols = c(.pred)) |> nrow(), + length(cars_te) * length(k_vals) + ) expect_equal(pred_multi |> nrow(), length(cars_te)) pred_uni <- predict(res_xy, mtcars[cars_te, -1]) @@ -191,15 +209,13 @@ test_that('argument checks for data dimensions', { set_mode("regression") expect_snapshot( - f_fit <- spec |> fit(body_mass_g ~ ., data = penguins) + f_fit <- spec |> fit(body_mass_g ~ ., data = penguins) ) expect_snapshot( xy_fit <- spec |> fit_xy(x = penguins[, -6], y = penguins$body_mass_g) ) - expect_equal(extract_fit_engine(f_fit)$best.parameters$k, nrow(penguins) - 5) + expect_equal(extract_fit_engine(f_fit)$best.parameters$k, nrow(penguins) - 5) expect_equal(extract_fit_engine(xy_fit)$best.parameters$k, nrow(penguins) - 5) - }) - diff --git a/tests/testthat/test-nullmodel.R b/tests/testthat/test-nullmodel.R index 47e31bbf1..9d20f0a06 100644 --- a/tests/testthat/test-nullmodel.R +++ b/tests/testthat/test-nullmodel.R @@ -3,11 +3,15 @@ skip_if_not_installed("modeldata") hpc <- hpc_data[1:150, c(2:5, 8)] |> as.data.frame() test_that('bad input', { - expect_snapshot(error = TRUE, translate(null_model(mode = "regression") |> set_engine())) + expect_snapshot( + error = TRUE, + translate(null_model(mode = "regression") |> set_engine()) + ) expect_snapshot(error = TRUE, translate(null_model() |> set_engine("wat?"))) expect_warning( translate( - null_model(mode = "regression") |> set_engine("parsnip", x = hpc[,1:3], y = hpc$class) + null_model(mode = "regression") |> + set_engine("parsnip", x = hpc[, 1:3], y = hpc$class) ), class = "parsnip_protected_arg_warning" ) @@ -21,7 +25,6 @@ hpc_bad_form <- as.formula(class ~ term) # ------------------------------------------------------------------------------ test_that('nullmodel execution', { - expect_no_condition( res <- fit( null_model(mode = "regression") |> set_engine("parsnip"), @@ -76,15 +79,12 @@ test_that('nullmodel execution', { data = hpc ) ) - }) test_that('nullmodel prediction', { - uni_pred <- tibble(.pred = rep(30.1, 5)) inl_pred <- rep(30.1, 5) - mw_pred <- tibble(gear = rep(3.6875, 5), - carb = rep(2.8125, 5)) + mw_pred <- tibble(gear = rep(3.6875, 5), carb = rep(2.8125, 5)) res_xy <- 
fit_xy( null_model(mode = "regression") |> set_engine("parsnip"), @@ -92,18 +92,18 @@ test_that('nullmodel prediction', { y = hpc$num_pending ) - expect_equal(uni_pred, - predict(res_xy, new_data = hpc[1:5, num_pred]), - tolerance = .01) + expect_equal( + uni_pred, + predict(res_xy, new_data = hpc[1:5, num_pred]), + tolerance = .01 + ) res_form <- fit( null_model(mode = "regression") |> set_engine("parsnip"), num_pending ~ log(compounds) + class, data = hpc ) - expect_equal(inl_pred, - predict(res_form, hpc[1:5, ])$.pred, - tolerance = .01) + expect_equal(inl_pred, predict(res_form, hpc[1:5, ])$.pred, tolerance = .01) # Multivariate y res <- fit( @@ -121,7 +121,6 @@ test_that('nullmodel prediction', { # ------------------------------------------------------------------------------ test_that('classification', { - expect_no_condition( null_model <- null_model(mode = "classification") |> set_engine("parsnip") |> @@ -182,7 +181,11 @@ test_that("null_model works with sparse matrix data - classification", { hotel_data <- sparse_hotel_rates() # Create a factor outcome for classification - y_class <- factor(ifelse(hotel_data[, 1] > median(hotel_data[, 1]), "high", "low")) + y_class <- factor(ifelse( + hotel_data[, 1] > median(hotel_data[, 1]), + "high", + "low" + )) spec <- null_model(mode = "classification") |> set_engine("parsnip") diff --git a/tests/testthat/test-packages.R b/tests/testthat/test-packages.R index eb83d20b8..f0c5ecd8c 100644 --- a/tests/testthat/test-packages.R +++ b/tests/testthat/test-packages.R @@ -1,4 +1,3 @@ - load(test_path("mars_model.RData")) # ------------------------------------------------------------------------------ @@ -38,13 +37,10 @@ test_that('missing packages', { if (has_earth) { expect_no_condition(predict(mars_model, mtcars[1:3, -1])) - } else { expect_snapshot(error = TRUE, predict(mars_model, mtcars[1:3, -1])) expect_true(any(names(sessionInfo()$loadedOnly) == "earth")) } mars_model$spec$method$libs <- "rootveggie" expect_snapshot(error = TRUE, predict(mars_model, mtcars[1:3, -1])) - }) - diff --git a/tests/testthat/test-partykit.R b/tests/testthat/test-partykit.R index 6dcfdfc07..26615e6ed 100644 --- a/tests/testthat/test-partykit.R +++ b/tests/testthat/test-partykit.R @@ -1,5 +1,3 @@ - - test_that('fit ctree models', { skip_if_not_installed("modeldata") skip_if_not_installed("partykit") @@ -12,24 +10,36 @@ test_that('fit ctree models', { fit_1 <- ctree_train(ridership ~ ., data = Chicago[, 1:20]) ) expect_no_condition( - fit_2 <- ctree_train(ridership ~ ., data = Chicago[, 1:20], - mincriterion = 1/2, maxdepth = 2) + fit_2 <- ctree_train( + ridership ~ ., + data = Chicago[, 1:20], + mincriterion = 1 / 2, + maxdepth = 2 + ) ) - expect_equal(fit_2$info$control$logmincriterion, log(1/2)) + expect_equal(fit_2$info$control$logmincriterion, log(1 / 2)) expect_equal(fit_2$info$control$maxdepth, 2) expect_no_condition( - fit_3 <- ctree_train(ridership ~ ., data = Chicago[, 1:20], - mincriterion = 1/2, maxdepth = 2, - weights = 1:nrow(Chicago)) + fit_3 <- ctree_train( + ridership ~ ., + data = Chicago[, 1:20], + mincriterion = 1 / 2, + maxdepth = 2, + weights = 1:nrow(Chicago) + ) ) expect_false(isTRUE(all.equal(fit_2$fitted, fit_3$fitted))) expect_no_condition( fit_4 <- ctree_train(Class ~ ., data = ad_data) ) expect_snapshot_error( - ctree_train(ridership ~ ., data = Chicago[, 1:20], - mincriterion = 1/2, maxdepth = 2, - weights = runif(nrow(Chicago))) + ctree_train( + ridership ~ ., + data = Chicago[, 1:20], + mincriterion = 1 / 2, + maxdepth = 2, + weights = 
runif(nrow(Chicago)) + ) ) }) @@ -46,16 +56,28 @@ test_that('fit cforest models', { ) expect_equal(length(fit_1$nodes), 2) expect_no_condition( - fit_2 <- cforest_train(ridership ~ ., data = Chicago[, 1:5], ntree = 2, - mincriterion = 1/2, maxdepth = 2, mtry = 4) + fit_2 <- cforest_train( + ridership ~ ., + data = Chicago[, 1:5], + ntree = 2, + mincriterion = 1 / 2, + maxdepth = 2, + mtry = 4 + ) ) - expect_equal(fit_2$info$control$logmincriterion, log(1/2)) + expect_equal(fit_2$info$control$logmincriterion, log(1 / 2)) expect_equal(fit_2$info$control$maxdepth, 2) expect_equal(fit_2$info$control$mtry, 4) expect_no_condition( - fit_3 <- cforest_train(ridership ~ ., data = Chicago[, 1:5], ntree = 2, - mincriterion = 1/2, maxdepth = 2, mtry = 4, - weights = 1:nrow(Chicago)) + fit_3 <- cforest_train( + ridership ~ ., + data = Chicago[, 1:5], + ntree = 2, + mincriterion = 1 / 2, + maxdepth = 2, + mtry = 4, + weights = 1:nrow(Chicago) + ) ) expect_false(isTRUE(all.equal(fit_2$fitted, fit_3$fitted))) expect_no_condition( diff --git a/tests/testthat/test-pls.R b/tests/testthat/test-pls.R index 6d458c703..a885cd595 100644 --- a/tests/testthat/test-pls.R +++ b/tests/testthat/test-pls.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/plsmod expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-poisson_reg.R b/tests/testthat/test-poisson_reg.R index 5bf960840..4b270842a 100644 --- a/tests/testthat/test-poisson_reg.R +++ b/tests/testthat/test-poisson_reg.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/poissonreg expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index d19f2c040..26f5c2060 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -4,13 +4,12 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ - lm_fit <- linear_reg(mode = "regression") |> set_engine("lm") |> fit(compounds ~ ., data = hpc) -class_dat <- airquality[complete.cases(airquality),] +class_dat <- airquality[complete.cases(airquality), ] class_dat$Ozone <- factor(ifelse(class_dat$Ozone >= 31, "high", "low")) lr_fit <- @@ -18,8 +17,12 @@ lr_fit <- set_engine("glm") |> fit(Ozone ~ ., data = class_dat) -class_dat2 <- airquality[complete.cases(airquality),] -class_dat2$Ozone <- factor(ifelse(class_dat2$Ozone >= 31, "high+values", "2low")) +class_dat2 <- airquality[complete.cases(airquality), ] +class_dat2$Ozone <- factor(ifelse( + class_dat2$Ozone >= 31, + "high+values", + "2low" +)) lr_fit_2 <- logistic_reg() |> @@ -29,20 +32,38 @@ lr_fit_2 <- # ------------------------------------------------------------------------------ test_that('regression predictions', { - expect_true(is_tibble(predict(lm_fit, new_data = hpc[1:5,-1]))) - expect_true(is.vector(parsnip:::predict_numeric.model_fit(lm_fit, new_data = hpc[1:5,-1]))) - expect_equal(names(predict(lm_fit, new_data = hpc[1:5,-1])), ".pred") + expect_true(is_tibble(predict(lm_fit, new_data = hpc[1:5, -1]))) + expect_true(is.vector(parsnip:::predict_numeric.model_fit( + lm_fit, + new_data = hpc[1:5, -1] + ))) + expect_equal(names(predict(lm_fit, new_data = hpc[1:5, -1])), ".pred") }) test_that('classification predictions', { - expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) - expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1]))) - 
expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1])), ".pred_class") - - expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1], type = "prob"))) - expect_true(is_tibble(parsnip:::predict_classprob.model_fit(lr_fit, new_data = class_dat[1:5,-1]))) - expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1], type = "prob")), - c(".pred_high", ".pred_low")) + expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5, -1]))) + expect_true(is.factor(parsnip:::predict_class.model_fit( + lr_fit, + new_data = class_dat[1:5, -1] + ))) + expect_equal( + names(predict(lr_fit, new_data = class_dat[1:5, -1])), + ".pred_class" + ) + + expect_true(is_tibble(predict( + lr_fit, + new_data = class_dat[1:5, -1], + type = "prob" + ))) + expect_true(is_tibble(parsnip:::predict_classprob.model_fit( + lr_fit, + new_data = class_dat[1:5, -1] + ))) + expect_equal( + names(predict(lr_fit, new_data = class_dat[1:5, -1], type = "prob")), + c(".pred_high", ".pred_low") + ) }) @@ -54,16 +75,18 @@ test_that('ordinal classification predictions', { dat_tr <- modeldata::sim_multinomial( 200, - ~ -0.5 + 0.6 * abs(A), - ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), - ~ -0.6 * A + 0.50 * B - A * B) |> + ~ -0.5 + 0.6 * abs(A), + ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, -2), + ~ -0.6 * A + 0.50 * B - A * B + ) |> dplyr::mutate(class = as.ordered(class)) dat_te <- modeldata::sim_multinomial( 5, - ~ -0.5 + 0.6 * abs(A), - ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), - ~ -0.6 * A + 0.50 * B - A * B) |> + ~ -0.5 + 0.6 * abs(A), + ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, -2), + ~ -0.6 * A + 0.50 * B - A * B + ) |> dplyr::mutate(class = as.ordered(class)) ### @@ -90,16 +113,36 @@ test_that('ordinal classification predictions', { test_that('non-standard levels', { - expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) - expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1]))) - expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1])), ".pred_class") - - expect_true(is_tibble(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob"))) - expect_true(is_tibble(parsnip:::predict_classprob.model_fit(lr_fit_2, new_data = class_dat2[1:5,-1]))) - expect_equal(names(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")), - c(".pred_2low", ".pred_high+values")) - expect_equal(names(parsnip:::predict_classprob.model_fit(lr_fit_2, new_data = class_dat2[1:5,-1])), - c("2low", "high+values")) + expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5, -1]))) + expect_true(is.factor(parsnip:::predict_class.model_fit( + lr_fit, + new_data = class_dat[1:5, -1] + ))) + expect_equal( + names(predict(lr_fit, new_data = class_dat[1:5, -1])), + ".pred_class" + ) + + expect_true(is_tibble(predict( + lr_fit_2, + new_data = class_dat2[1:5, -1], + type = "prob" + ))) + expect_true(is_tibble(parsnip:::predict_classprob.model_fit( + lr_fit_2, + new_data = class_dat2[1:5, -1] + ))) + expect_equal( + names(predict(lr_fit_2, new_data = class_dat2[1:5, -1], type = "prob")), + c(".pred_2low", ".pred_high+values") + ) + expect_equal( + names(parsnip:::predict_classprob.model_fit( + lr_fit_2, + new_data = class_dat2[1:5, -1] + )), + c("2low", "high+values") + ) }) test_that('predict(type = "prob") with level "class" (see #720)', { @@ -133,15 +176,16 @@ test_that('non-factor classification', { error = TRUE, logistic_reg() |> set_engine("glm") |> - fit(class ~ ., - data = hpc |> dplyr::mutate(class = class == "VF")) + fit(class ~ ., data 
= hpc |> dplyr::mutate(class = class == "VF")) ) expect_snapshot( error = TRUE, logistic_reg() |> set_engine("glm") |> - fit(class ~ ., - data = hpc |> dplyr::mutate(class = ifelse(class == "VF", 1, 0))) + fit( + class ~ ., + data = hpc |> dplyr::mutate(class = ifelse(class == "VF", 1, 0)) + ) ) skip_if_not_installed("glmnet") @@ -149,8 +193,7 @@ test_that('non-factor classification', { error = TRUE, multinom_reg() |> set_engine("glmnet") |> - fit(class ~ ., - data = hpc |> dplyr::mutate(class = as.character(class))) + fit(class ~ ., data = hpc |> dplyr::mutate(class = as.character(class))) ) }) diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R index cea7846b1..9eb6220c7 100644 --- a/tests/testthat/test-print.R +++ b/tests/testthat/test-print.R @@ -19,8 +19,10 @@ test_that("`print_model_spec()` handles args correctly", { test_that("`get_model_desc()` retrieves/creates model description well", { expect_equal(get_model_desc("linear_reg"), "Linear Regression") expect_equal(get_model_desc("boost_tree"), "Boosted Tree") - expect_equal(get_model_desc("boost_tree"), - model_descs$desc[model_descs$cls == "boost_tree"]) + expect_equal( + get_model_desc("boost_tree"), + model_descs$desc[model_descs$cls == "boost_tree"] + ) expect_equal(get_model_desc("goofy new class"), "goofy new class") expect_equal(get_model_desc("goofy_new_class"), "goofy new class") diff --git a/tests/testthat/test-proportional_hazards.R b/tests/testthat/test-proportional_hazards.R index a60657639..b686faebf 100644 --- a/tests/testthat/test-proportional_hazards.R +++ b/tests/testthat/test-proportional_hazards.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/censored expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-rand_forest_ranger.R b/tests/testthat/test-rand_forest_ranger.R index 158d0cf98..8974a9bc5 100644 --- a/tests/testthat/test-rand_forest_ranger.R +++ b/tests/testthat/test-rand_forest_ranger.R @@ -8,15 +8,17 @@ lending_club <- head(lending_club, 200) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") lc_basic <- rand_forest(mode = "classification") |> set_engine("ranger") -lc_ranger <- rand_forest(mode = "classification") |> set_engine("ranger", seed = 144) +lc_ranger <- rand_forest(mode = "classification") |> + set_engine("ranger", seed = 144) -bad_ranger_cls <- rand_forest(mode = "classification") |> set_engine("ranger", replace = "bad") -bad_rf_cls <- rand_forest(mode = "classification") |> set_engine("ranger", sampsize = -10) +bad_ranger_cls <- rand_forest(mode = "classification") |> + set_engine("ranger", replace = "bad") +bad_rf_cls <- rand_forest(mode = "classification") |> + set_engine("ranger", sampsize = -10) # ------------------------------------------------------------------------------ test_that('ranger classification execution', { - skip_if_not_installed("ranger") expect_no_condition( @@ -65,11 +67,9 @@ test_that('ranger classification execution', { y = lending_club$Class ) expect_true(inherits(extract_fit_engine(ranger_xy_catch), "try-error")) - }) test_that('ranger classification prediction', { - skip_if_not_installed("ranger") xy_fit <- fit_xy( @@ -79,12 +79,19 @@ test_that('ranger classification prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), data = lending_club[1:6, num_pred])$prediction + xy_pred <- predict( + extract_fit_engine(xy_fit), + data = lending_club[1:6, num_pred] + )$prediction xy_pred <- colnames(xy_pred)[apply(xy_pred, 1, which.max)] xy_pred <- factor(xy_pred, levels 
= levels(lending_club$Class)) expect_equal( xy_pred, - predict(xy_fit, new_data = lending_club[1:6, num_pred], type = "class")$.pred_class + predict( + xy_fit, + new_data = lending_club[1:6, num_pred], + type = "class" + )$.pred_class ) form_fit <- fit( @@ -95,30 +102,39 @@ test_that('ranger classification prediction', { control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), data = lending_club[1:6, c("funded_amnt", "int_rate")])$prediction + form_pred <- predict( + extract_fit_engine(form_fit), + data = lending_club[1:6, c("funded_amnt", "int_rate")] + )$prediction form_pred <- colnames(form_pred)[apply(form_pred, 1, which.max)] form_pred <- factor(form_pred, levels = levels(lending_club$Class)) expect_equal( form_pred, - predict(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])$.pred_class + predict( + form_fit, + new_data = lending_club[1:6, c("funded_amnt", "int_rate")] + )$.pred_class ) - }) test_that('ranger classification probabilities', { - skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest() |> set_mode("classification") |> set_engine("ranger", seed = 3566), + rand_forest() |> + set_mode("classification") |> + set_engine("ranger", seed = 3566), x = lending_club[, num_pred], y = lending_club$Class, control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), data = lending_club[1:6, num_pred])$predictions + xy_pred <- predict( + extract_fit_engine(xy_fit), + data = lending_club[1:6, num_pred] + )$predictions xy_pred <- as_tibble(xy_pred) names(xy_pred) <- paste0(".pred_", names(xy_pred)) expect_equal( @@ -126,27 +142,41 @@ test_that('ranger classification probabilities', { predict(xy_fit, new_data = lending_club[1:6, num_pred], type = "prob") ) - one_row <- predict(xy_fit, new_data = lending_club[1, num_pred], type = "prob") - expect_equal(xy_pred[1,], one_row) + one_row <- predict( + xy_fit, + new_data = lending_club[1, num_pred], + type = "prob" + ) + expect_equal(xy_pred[1, ], one_row) form_fit <- fit( - rand_forest() |> set_mode("classification") |> set_engine("ranger", seed = 3566), + rand_forest() |> + set_mode("classification") |> + set_engine("ranger", seed = 3566), Class ~ funded_amnt + int_rate, data = lending_club, control = ctrl ) - form_pred <- predict(extract_fit_engine(form_fit), data = lending_club[1:6, c("funded_amnt", "int_rate")])$predictions + form_pred <- predict( + extract_fit_engine(form_fit), + data = lending_club[1:6, c("funded_amnt", "int_rate")] + )$predictions form_pred <- as_tibble(form_pred) names(form_pred) <- paste0(".pred_", names(form_pred)) expect_equal( form_pred, - predict(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")], type = "prob") + predict( + form_fit, + new_data = lending_club[1:6, c("funded_amnt", "int_rate")], + type = "prob" + ) ) no_prob_model <- fit_xy( - rand_forest(mode = "classification") |> set_engine("ranger", probability = FALSE), + rand_forest(mode = "classification") |> + set_engine("ranger", probability = FALSE), x = lending_club[, num_pred], y = lending_club$Class, control = ctrl @@ -154,7 +184,10 @@ test_that('ranger classification probabilities', { expect_snapshot( error = TRUE, - parsnip:::predict_classprob.model_fit(no_prob_model, new_data = lending_club[1:6, num_pred]) + parsnip:::predict_classprob.model_fit( + no_prob_model, + new_data = lending_club[1:6, num_pred] + ) ) }) @@ -164,13 +197,14 @@ num_pred <- names(mtcars)[3:6] car_basic <- rand_forest(mode = "regression") |> set_engine("ranger") -bad_ranger_reg <- rand_forest(mode = "regression") 
|> set_engine("ranger", replace = "bad") -bad_rf_reg <- rand_forest(mode = "regression") |> set_engine("ranger", sampsize = -10) +bad_ranger_reg <- rand_forest(mode = "regression") |> + set_engine("ranger", replace = "bad") +bad_rf_reg <- rand_forest(mode = "regression") |> + set_engine("ranger", sampsize = -10) # ------------------------------------------------------------------------------ test_that('ranger regression execution', { - skip_if_not_installed("ranger") expect_no_condition( @@ -191,7 +225,6 @@ test_that('ranger regression execution', { ) ) - ranger_form_catch <- fit( bad_ranger_reg, mpg ~ ., @@ -207,11 +240,9 @@ test_that('ranger regression execution', { y = mtcars$mpg ) expect_true(inherits(extract_fit_engine(ranger_xy_catch), "try-error")) - }) test_that('ranger regression prediction', { - skip_if_not_installed("ranger") xy_fit <- fit_xy( @@ -221,10 +252,12 @@ test_that('ranger regression prediction', { control = ctrl ) - xy_pred <- predict(extract_fit_engine(xy_fit), data = tail(mtcars[, -1]))$prediction + xy_pred <- predict( + extract_fit_engine(xy_fit), + data = tail(mtcars[, -1]) + )$prediction expect_equal(xy_pred, predict(xy_fit, new_data = tail(mtcars[, -1]))$.pred) - }) @@ -244,30 +277,38 @@ test_that('ranger regression intervals', { control = ctrl ) - rgr_pred <- predict(extract_fit_engine(xy_fit), data = head(ames_x, 3))$predictions + rgr_pred <- predict( + extract_fit_engine(xy_fit), + data = head(ames_x, 3) + )$predictions expect_snapshot( rgr_se <- - predict(extract_fit_engine(xy_fit), data = head(ames_x, 3), type = "se")$se + predict( + extract_fit_engine(xy_fit), + data = head(ames_x, 3), + type = "se" + )$se ) rgr_lower <- rgr_pred - qnorm(0.035, lower.tail = FALSE) * rgr_se rgr_upper <- rgr_pred + qnorm(0.035, lower.tail = FALSE) * rgr_se expect_snapshot( parsnip_int <- - predict(xy_fit, new_data = head(ames_x, 3), - type = "conf_int", std_error = TRUE, level = 0.93 + predict( + xy_fit, + new_data = head(ames_x, 3), + type = "conf_int", + std_error = TRUE, + level = 0.93 ) ) expect_equal(rgr_lower, parsnip_int$.pred_lower, ignore_formula_env = TRUE) expect_equal(rgr_upper, parsnip_int$.pred_upper) expect_equal(rgr_se, parsnip_int$.std_error) - }) - test_that('additional descriptor tests', { - skip_if_not_installed("ranger") descr_xy <- fit_xy( @@ -282,7 +323,8 @@ test_that('additional descriptor tests', { descr_f <- fit( rand_forest(mode = "regression", mtry = floor(sqrt(.cols())) + 1) |> set_engine("ranger"), - mpg ~ ., data = mtcars, + mpg ~ ., + data = mtcars, control = ctrl ) expect_equal(extract_fit_engine(descr_f)$mtry, 4) @@ -299,7 +341,8 @@ test_that('additional descriptor tests', { descr_f <- fit( rand_forest(mode = "regression", mtry = floor(sqrt(.cols())) + 1) |> set_engine("ranger"), - mpg ~ ., data = mtcars, + mpg ~ ., + data = mtcars, control = ctrl ) expect_equal(extract_fit_engine(descr_f)$mtry, 4) @@ -316,18 +359,25 @@ test_that('additional descriptor tests', { control = ctrl ) expect_equal(extract_fit_engine(descr_other_xy)$mtry, 2) - expect_equal(extract_fit_engine(descr_other_xy)$call$class.weights, exp_wts, - ignore_formula_env = TRUE) + expect_equal( + extract_fit_engine(descr_other_xy)$call$class.weights, + exp_wts, + ignore_formula_env = TRUE + ) descr_other_f <- fit( rand_forest(mode = "classification", mtry = 2) |> set_engine("ranger", class.weights = c(min(.lvls()), 20, 10, 1)), - class ~ ., data = hpc, + class ~ ., + data = hpc, control = ctrl ) expect_equal(extract_fit_engine(descr_other_f)$mtry, 2) - 
expect_equal(extract_fit_engine(descr_other_f)$call$class.weights, exp_wts, - ignore_formula_env = TRUE) + expect_equal( + extract_fit_engine(descr_other_f)$call$class.weights, + exp_wts, + ignore_formula_env = TRUE + ) descr_other_xy <- fit_xy( rand_forest(mode = "classification", mtry = 2) |> @@ -337,27 +387,35 @@ test_that('additional descriptor tests', { control = ctrl ) expect_equal(extract_fit_engine(descr_other_xy)$mtry, 2) - expect_equal(extract_fit_engine(descr_other_xy)$call$class.weights, exp_wts, - ignore_formula_env = TRUE) + expect_equal( + extract_fit_engine(descr_other_xy)$call$class.weights, + exp_wts, + ignore_formula_env = TRUE + ) descr_other_f <- fit( rand_forest(mode = "classification", mtry = 2) |> set_engine("ranger", class.weights = c(min(.lvls()), 20, 10, 1)), - class ~ ., data = hpc, + class ~ ., + data = hpc, control = ctrl ) expect_equal(extract_fit_engine(descr_other_f)$mtry, 2) - expect_equal(extract_fit_engine(descr_other_f)$call$class.weights, exp_wts, - ignore_formula_env = TRUE) + expect_equal( + extract_fit_engine(descr_other_f)$call$class.weights, + exp_wts, + ignore_formula_env = TRUE + ) }) test_that('ranger classification prediction', { - skip_if_not_installed("ranger") xy_class_fit <- - rand_forest() |> set_mode("classification") |> set_engine("ranger") |> + rand_forest() |> + set_mode("classification") |> + set_engine("ranger") |> fit_xy( x = hpc[, 1:4], y = hpc$class, @@ -367,7 +425,10 @@ test_that('ranger classification prediction', { expect_false(has_multi_predict(xy_class_fit)) expect_equal(multi_predict_args(xy_class_fit), NA_character_) - xy_class_pred <- predict(extract_fit_engine(xy_class_fit), data = hpc[c(1, 51, 101), 1:4])$prediction + xy_class_pred <- predict( + extract_fit_engine(xy_class_fit), + data = hpc[c(1, 51, 101), 1:4] + )$prediction xy_class_pred <- colnames(xy_class_pred)[apply(xy_class_pred, 1, which.max)] xy_class_pred <- factor(xy_class_pred, levels = levels(hpc$class)) @@ -386,7 +447,10 @@ test_that('ranger classification prediction', { control = ctrl ) - xy_prob_pred <- predict(extract_fit_engine(xy_prob_fit), data = hpc[c(1, 51, 101), 1:4])$prediction + xy_prob_pred <- predict( + extract_fit_engine(xy_prob_fit), + data = hpc[c(1, 51, 101), 1:4] + )$prediction xy_prob_pred <- colnames(xy_prob_pred)[apply(xy_prob_pred, 1, which.max)] xy_prob_pred <- factor(xy_prob_pred, levels = levels(hpc$class)) @@ -395,7 +459,11 @@ test_that('ranger classification prediction', { predict(xy_prob_fit, new_data = hpc[c(1, 51, 101), 1:4])$.pred_class ) - xy_prob_prob <- predict(extract_fit_engine(xy_prob_fit), data = hpc[c(1, 51, 101), 1:4], type = "response") + xy_prob_prob <- predict( + extract_fit_engine(xy_prob_fit), + data = hpc[c(1, 51, 101), 1:4], + type = "response" + ) xy_prob_prob <- as_tibble(xy_prob_prob$prediction) names(xy_prob_prob) <- paste0(".pred_", names(xy_prob_prob)) expect_equal( @@ -406,7 +474,6 @@ test_that('ranger classification prediction', { test_that('ranger classification intervals', { - skip_if_not_installed("ranger") lc_fit <- fit( @@ -417,9 +484,16 @@ test_that('ranger classification intervals', { control = ctrl ) - rgr_pred <- predict(extract_fit_engine(lc_fit), data = tail(lending_club))$predictions + rgr_pred <- predict( + extract_fit_engine(lc_fit), + data = tail(lending_club) + )$predictions expect_snapshot( - rgr_se <- predict(extract_fit_engine(lc_fit), data = tail(lending_club), type = "se")$se + rgr_se <- predict( + extract_fit_engine(lc_fit), + data = tail(lending_club), + type = "se" + )$se ) 
rgr_lower <- rgr_pred - qnorm(0.035, lower.tail = FALSE) * rgr_se rgr_upper <- rgr_pred + qnorm(0.035, lower.tail = FALSE) * rgr_se @@ -428,8 +502,12 @@ test_that('ranger classification intervals', { expect_snapshot( parsnip_int <- - predict(lc_fit, new_data = tail(lending_club), - type = "conf_int", std_error = TRUE, level = 0.93 + predict( + lc_fit, + new_data = tail(lending_club), + type = "conf_int", + std_error = TRUE, + level = 0.93 ) ) expect_equal(rgr_lower[, "bad"], parsnip_int$.pred_lower_bad) @@ -438,11 +516,9 @@ test_that('ranger classification intervals', { expect_equal(rgr_upper[, "good"], parsnip_int$.pred_upper_good) expect_equal(rgr_se[, 1], parsnip_int$.std_error_bad) expect_equal(rgr_se[, 2], parsnip_int$.std_error_good) - }) - test_that('ranger and sparse matrices', { skip_if_not_installed("ranger") @@ -470,7 +546,6 @@ test_that('ranger and sparse matrices', { ## ----------------------------------------------------------------------------- test_that('argument checks for data dimensions', { - skip_if_not_installed("ranger") data(penguins, package = "modeldata") @@ -482,16 +557,14 @@ test_that('argument checks for data dimensions', { set_mode("regression") expect_snapshot( - f_fit <- spec |> fit(body_mass_g ~ ., data = penguins) + f_fit <- spec |> fit(body_mass_g ~ ., data = penguins) ) expect_snapshot( xy_fit <- spec |> fit_xy(x = penguins[, -6], y = penguins$body_mass_g) ) - expect_equal(extract_fit_engine(f_fit)$mtry, 6) expect_equal(extract_fit_engine(f_fit)$min.node.size, nrow(penguins)) expect_equal(extract_fit_engine(xy_fit)$mtry, 6) expect_equal(extract_fit_engine(xy_fit)$min.node.size, nrow(penguins)) - }) diff --git a/tests/testthat/test-re_registration.R b/tests/testthat/test-re_registration.R index c00a9bf8e..0e9a542de 100644 --- a/tests/testthat/test-re_registration.R +++ b/tests/testthat/test-re_registration.R @@ -1,7 +1,6 @@ # For issue #653 we want to be able to re-run the registration code as # long as the information being registered is the same. 
- test_that('re-registration of mode', { old_val <- get_from_env("bart_modes") expect_no_condition(set_model_mode("bart", "classification")) @@ -115,12 +114,11 @@ test_that('re-registration of prediction information', { pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = quote(object), - new_data = quote(new_data), - type = "numeric" - ) + args = list( + obj = quote(object), + new_data = quote(new_data), + type = "numeric" + ) ) ) ) @@ -140,14 +138,12 @@ test_that('re-registration of prediction information', { pre = NULL, post = NULL, func = c(pkg = "parsnip", fun = "dbart_predict_calc"), - args = - list( - obj = quote(object), - new_data = quote(new_data), - type = "tuba" - ) + args = list( + obj = quote(object), + new_data = quote(new_data), + type = "tuba" + ) ) ) ) }) - diff --git a/tests/testthat/test-rule_fit.R b/tests/testthat/test-rule_fit.R index 0b3c40e3c..2cadaf221 100644 --- a/tests/testthat/test-rule_fit.R +++ b/tests/testthat/test-rule_fit.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/rules expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index f6cb6f4e9..c6e1f5141 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -50,7 +50,6 @@ test_that("sparse matrix can be passed to `fit() - supported", { error = TRUE, xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) ) - }) test_that("sparse matrix can be passed to `fit() - unsupported", { @@ -217,11 +216,11 @@ test_that("sparse data work with xgboost engine", { expect_no_error( xgb_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) - ) + ) - expect_no_error( - predict(xgb_fit, hotel_data) - ) + expect_no_error( + predict(xgb_fit, hotel_data) + ) hotel_data <- sparse_hotel_rates(tibble = TRUE) @@ -322,11 +321,19 @@ test_that("maybe_sparse_matrix() is used correctly", { ) expect_snapshot( error = TRUE, - fit_xy(spec, x = as.data.frame(mtcars)[, -1], y = as.data.frame(mtcars)[, 1]) + fit_xy( + spec, + x = as.data.frame(mtcars)[, -1], + y = as.data.frame(mtcars)[, 1] + ) ) expect_snapshot( error = TRUE, - fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[, 1]) + fit_xy( + spec, + x = tibble::as_tibble(mtcars)[, -1], + y = tibble::as_tibble(mtcars)[, 1] + ) ) }) diff --git a/tests/testthat/test-survival-censoring-weights.R b/tests/testthat/test-survival-censoring-weights.R index 770b5f6a1..1435a4ce2 100644 --- a/tests/testthat/test-survival-censoring-weights.R +++ b/tests/testthat/test-survival-censoring-weights.R @@ -15,7 +15,7 @@ test_that("probability truncation via trunc_probs()", { expect_equal(probs_trunc_04_na[2], data_derived_trunc) expect_equal(probs_trunc_04_na[3:6], probs[2:5]) - probs <- (1:200)/200 + probs <- (1:200) / 200 expect_identical( parsnip:::trunc_probs(probs, trunc = 0.01), probs diff --git a/tests/testthat/test-survival_reg.R b/tests/testthat/test-survival_reg.R index a60657639..b686faebf 100644 --- a/tests/testthat/test-survival_reg.R +++ b/tests/testthat/test-survival_reg.R @@ -3,4 +3,4 @@ test_that("testing", { # https://github.com/tidymodels/censored expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-svm_linear.R b/tests/testthat/test-svm_linear.R index 3dfb7ca3e..5923ab877 100644 --- a/tests/testthat/test-svm_linear.R +++ b/tests/testthat/test-svm_linear.R @@ -13,21 +13,34 @@ test_that('updating', { }) test_that('bad 
input', { - expect_snapshot(error = TRUE, translate(svm_linear(mode = "regression") |> set_engine( NULL))) + expect_snapshot( + error = TRUE, + translate(svm_linear(mode = "regression") |> set_engine(NULL)) + ) expect_snapshot(error = TRUE, svm_linear(mode = "reallyunknown")) - expect_snapshot(error = TRUE, translate(svm_linear(mode = "regression") |> set_engine("LiblineaR", type = 3))) - expect_snapshot(error = TRUE, translate(svm_linear(mode = "classification") |> set_engine("LiblineaR", type = 11))) + expect_snapshot( + error = TRUE, + translate( + svm_linear(mode = "regression") |> set_engine("LiblineaR", type = 3) + ) + ) + expect_snapshot( + error = TRUE, + translate( + svm_linear(mode = "classification") |> set_engine("LiblineaR", type = 11) + ) + ) }) # ------------------------------------------------------------------------------ reg_mod <- - svm_linear(mode = "regression", cost = 1/4) |> + svm_linear(mode = "regression", cost = 1 / 4) |> set_engine("LiblineaR") |> set_mode("regression") cls_mod <- - svm_linear(mode = "classification", cost = 1/8) |> + svm_linear(mode = "classification", cost = 1 / 8) |> set_engine("LiblineaR") |> set_mode("classification") @@ -36,14 +49,13 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) # ------------------------------------------------------------------------------ test_that('linear svm regression: LiblineaR', { - skip_if_not_installed("LiblineaR") expect_no_condition( res <- fit_xy( reg_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$input_fields ) ) @@ -62,15 +74,13 @@ test_that('linear svm regression: LiblineaR', { control = ctrl ) ) - }) test_that('linear svm regression prediction: LiblineaR', { - skip_if_not_installed("LiblineaR") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(2, 1, 143) @@ -86,13 +96,16 @@ test_that('linear svm regression prediction: LiblineaR', { liblinear_pred <- structure( list(.pred = c(85.13979, 576.16232, 1886.10132)), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_pred <- predict(reg_form, hpc[ind, -c(2, 5)]) - expect_equal(as.data.frame(liblinear_pred), - as.data.frame(parsnip_pred), - tolerance = .0001) - + expect_equal( + as.data.frame(liblinear_pred), + as.data.frame(parsnip_pred), + tolerance = .0001 + ) reg_xy_form <- fit_xy( @@ -101,21 +114,25 @@ test_that('linear svm regression prediction: LiblineaR', { y = hpc$input_fields, control = ctrl ) - expect_equal(extract_fit_engine(reg_form)$W, extract_fit_engine(reg_xy_form)$W) + expect_equal( + extract_fit_engine(reg_form)$W, + extract_fit_engine(reg_xy_form)$W + ) parsnip_xy_pred <- predict(reg_xy_form, hpc[ind, -c(2, 5)]) - expect_equal(as.data.frame(liblinear_pred), - as.data.frame(parsnip_xy_pred), - tolerance = .0001) + expect_equal( + as.data.frame(liblinear_pred), + as.data.frame(parsnip_xy_pred), + tolerance = .0001 + ) }) # ------------------------------------------------------------------------------ test_that('linear svm classification: LiblineaR', { - skip_if_not_installed("LiblineaR") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(2, 1, 143) @@ -137,15 +154,13 @@ test_that('linear svm classification: LiblineaR', { control = ctrl ) ) - }) test_that('linear svm classification prediction: LiblineaR', { - skip_if_not_installed("LiblineaR") - hpc_no_m <- 
hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(4, 55, 143) @@ -160,11 +175,17 @@ test_that('linear svm classification prediction: LiblineaR', { ) liblinear_class <- - structure(list( - .pred_class = structure( - c(1L, 1L, 2L), - .Label = c("VF", "F", "L"), class = "factor")), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + structure( + list( + .pred_class = structure( + c(1L, 1L, 2L), + .Label = c("VF", "F", "L"), + class = "factor" + ) + ), + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_class <- predict(cls_form, hpc_no_m[ind, -5]) expect_equal(liblinear_class, parsnip_class) @@ -177,7 +198,10 @@ test_that('linear svm classification prediction: LiblineaR', { y = hpc_no_m$class, control = ctrl ) - expect_equal(extract_fit_engine(cls_form)$W, extract_fit_engine(cls_xy_form)$W) + expect_equal( + extract_fit_engine(cls_form)$W, + extract_fit_engine(cls_xy_form)$W + ) expect_snapshot( error = TRUE, @@ -188,18 +212,17 @@ test_that('linear svm classification prediction: LiblineaR', { error = TRUE, predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob") ) - }) # ------------------------------------------------------------------------------ reg_mod <- - svm_linear(mode = "regression", cost = 1/4) |> + svm_linear(mode = "regression", cost = 1 / 4) |> set_engine("kernlab") |> set_mode("regression") cls_mod <- - svm_linear(mode = "classification", cost = 1/8) |> + svm_linear(mode = "classification", cost = 1 / 8) |> set_engine("kernlab") |> set_mode("classification") @@ -208,14 +231,13 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) # ------------------------------------------------------------------------------ test_that('linear svm regression: kernlab', { - skip_if_not_installed("kernlab") expect_no_condition( res <- fit_xy( reg_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$input_fields ) ) @@ -230,15 +252,13 @@ test_that('linear svm regression: kernlab', { control = ctrl ) ) - }) test_that('linear svm regression prediction: kernlab', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(2, 1, 143) @@ -254,13 +274,16 @@ test_that('linear svm regression prediction: kernlab', { kernlab_pred <- structure( list(.pred = c(129.9097, 376.1049, 1032.8989)), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_pred <- predict(reg_form, hpc[ind, -c(2, 5)]) - expect_equal(as.data.frame(kernlab_pred), - as.data.frame(parsnip_pred), - tolerance = .0001) - + expect_equal( + as.data.frame(kernlab_pred), + as.data.frame(parsnip_pred), + tolerance = .0001 + ) reg_xy_form <- fit_xy( @@ -269,21 +292,25 @@ test_that('linear svm regression prediction: kernlab', { y = hpc$input_fields, control = ctrl ) - expect_equal(extract_fit_engine(reg_form)@alphaindex, extract_fit_engine(reg_xy_form)@alphaindex) + expect_equal( + extract_fit_engine(reg_form)@alphaindex, + extract_fit_engine(reg_xy_form)@alphaindex + ) parsnip_xy_pred <- predict(reg_xy_form, hpc[ind, -c(2, 5)]) - expect_equal(as.data.frame(kernlab_pred), - as.data.frame(parsnip_xy_pred), - tolerance = .0001) + expect_equal( + as.data.frame(kernlab_pred), + as.data.frame(parsnip_xy_pred), + tolerance = .0001 + ) }) # ------------------------------------------------------------------------------ 
test_that('linear svm classification: kernlab', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(2, 1, 143) @@ -305,15 +332,13 @@ test_that('linear svm classification: kernlab', { control = ctrl ) ) - }) test_that('linear svm classification prediction: kernlab', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(4, 55, 143) @@ -328,11 +353,17 @@ test_that('linear svm classification prediction: kernlab', { ) kernlab_class <- - structure(list( - .pred_class = structure( - c(1L, 1L, 3L), - .Label = c("VF", "F", "L"), class = "factor")), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + structure( + list( + .pred_class = structure( + c(1L, 1L, 3L), + .Label = c("VF", "F", "L"), + class = "factor" + ) + ), + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_class <- predict(cls_form, hpc_no_m[ind, -5]) expect_equal(kernlab_class, parsnip_class) @@ -345,11 +376,18 @@ test_that('linear svm classification prediction: kernlab', { y = hpc_no_m$class, control = ctrl ) - expect_equal(extract_fit_engine(cls_form)@alphaindex, extract_fit_engine(cls_xy_form)@alphaindex) + expect_equal( + extract_fit_engine(cls_form)@alphaindex, + extract_fit_engine(cls_xy_form)@alphaindex + ) library(kernlab) kern_probs <- - kernlab::predict(extract_fit_engine(cls_form), hpc_no_m[ind, -5], type = "probabilities") |> + kernlab::predict( + extract_fit_engine(cls_form), + hpc_no_m[ind, -5], + type = "probabilities" + ) |> as_tibble() |> setNames(c('.pred_VF', '.pred_F', '.pred_L')) @@ -358,7 +396,6 @@ test_that('linear svm classification prediction: kernlab', { parsnip_xy_probs <- predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob") expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs)) - }) test_that("check_args() works", { diff --git a/tests/testthat/test-svm_poly.R b/tests/testthat/test-svm_poly.R index 16286f521..ede9967f0 100644 --- a/tests/testthat/test-svm_poly.R +++ b/tests/testthat/test-svm_poly.R @@ -19,12 +19,12 @@ test_that('bad input', { # ------------------------------------------------------------------------------ reg_mod <- - svm_poly(mode = "regression", degree = 1, cost = 1/4) |> + svm_poly(mode = "regression", degree = 1, cost = 1 / 4) |> set_engine("kernlab") |> set_mode("regression") cls_mod <- - svm_poly(mode = "classification", degree = 2, cost = 1/8) |> + svm_poly(mode = "classification", degree = 2, cost = 1 / 8) |> set_engine("kernlab") |> set_mode("classification") @@ -33,14 +33,13 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) # ------------------------------------------------------------------------------ test_that('svm poly regression', { - skip_if_not_installed("kernlab") expect_no_condition( res <- fit_xy( reg_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$compounds ) ) @@ -56,12 +55,10 @@ test_that('svm poly regression', { control = ctrl ) ) - }) test_that('svm poly regression prediction', { - skip_if_not_installed("kernlab") reg_form <- @@ -79,16 +76,18 @@ test_that('svm poly regression prediction', { kern_pred <- structure( list( - .pred = c(164.4739, 139.8284, 133.8760)), - row.names = c(NA,-3L), + .pred = c(164.4739, 139.8284, 133.8760) + ), + row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame") ) parsnip_pred <- predict(reg_form, hpc[1:3, 
-c(1, 5)]) - expect_equal(as.data.frame(kern_pred), - as.data.frame(parsnip_pred), - tolerance = .0001) - + expect_equal( + as.data.frame(kern_pred), + as.data.frame(parsnip_pred), + tolerance = .0001 + ) reg_xy_form <- fit_xy( @@ -97,18 +96,22 @@ test_that('svm poly regression prediction', { y = hpc$compounds, control = ctrl ) - expect_equal(extract_fit_engine(reg_form)@alphaindex, extract_fit_engine(reg_xy_form)@alphaindex) + expect_equal( + extract_fit_engine(reg_form)@alphaindex, + extract_fit_engine(reg_xy_form)@alphaindex + ) parsnip_xy_pred <- predict(reg_xy_form, hpc[1:3, -c(1, 5)]) - expect_equal(as.data.frame(kern_pred), - as.data.frame(parsnip_xy_pred), - tolerance = .0001) + expect_equal( + as.data.frame(kern_pred), + as.data.frame(parsnip_xy_pred), + tolerance = .0001 + ) }) # ------------------------------------------------------------------------------ test_that('svm poly classification', { - skip_if_not_installed("kernlab") expect_no_condition( @@ -128,15 +131,13 @@ test_that('svm poly classification', { control = ctrl ) ) - }) test_that('svm poly classification probabilities', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(1, 2, 143) @@ -155,8 +156,11 @@ test_that('svm poly classification probabilities', { kern_class <- structure( list( - .pred_class = .pred_factor), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + .pred_class = .pred_factor + ), + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_class <- predict(cls_form, hpc_no_m[ind, -5]) expect_equal(kern_class, parsnip_class) @@ -169,11 +173,18 @@ test_that('svm poly classification probabilities', { y = hpc_no_m$class, control = ctrl ) - expect_equal(extract_fit_engine(cls_form)@alphaindex, extract_fit_engine(cls_xy_form)@alphaindex) + expect_equal( + extract_fit_engine(cls_form)@alphaindex, + extract_fit_engine(cls_xy_form)@alphaindex + ) library(kernlab) kern_probs <- - kernlab::predict(extract_fit_engine(cls_form), hpc_no_m[ind, -5], type = "probabilities") |> + kernlab::predict( + extract_fit_engine(cls_form), + hpc_no_m[ind, -5], + type = "probabilities" + ) |> as_tibble() |> setNames(c('.pred_VF', '.pred_F', '.pred_L')) diff --git a/tests/testthat/test-svm_rbf.R b/tests/testthat/test-svm_rbf.R index 9940d249a..86a5734c0 100644 --- a/tests/testthat/test-svm_rbf.R +++ b/tests/testthat/test-svm_rbf.R @@ -5,7 +5,8 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ test_that('engine arguments', { - kernlab_cv <- svm_rbf(mode = "regression") |> set_engine("kernlab", cross = 10) + kernlab_cv <- svm_rbf(mode = "regression") |> + set_engine("kernlab", cross = 10) expect_snapshot(translate(kernlab_cv, "kernlab")$method$fit$args) }) @@ -21,18 +22,21 @@ test_that('updating', { test_that('bad input', { expect_snapshot(error = TRUE, svm_rbf(mode = "reallyunknown")) - expect_snapshot(error = TRUE, translate(svm_rbf(mode = "regression") |> set_engine( NULL))) + expect_snapshot( + error = TRUE, + translate(svm_rbf(mode = "regression") |> set_engine(NULL)) + ) }) # ------------------------------------------------------------------------------ reg_mod <- - svm_rbf(mode = "regression", rbf_sigma = .1, cost = 1/4) |> + svm_rbf(mode = "regression", rbf_sigma = .1, cost = 1 / 4) |> set_engine("kernlab") |> set_mode("regression") cls_mod <- - svm_rbf(mode = "classification", rbf_sigma = .1, cost = 
1/8) |> + svm_rbf(mode = "classification", rbf_sigma = .1, cost = 1 / 8) |> set_engine("kernlab") |> set_mode("classification") @@ -41,14 +45,13 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) # ------------------------------------------------------------------------------ test_that('svm poly regression', { - skip_if_not_installed("kernlab") expect_no_condition( res <- fit_xy( reg_mod, control = ctrl, - x = hpc[,2:4], + x = hpc[, 2:4], y = hpc$input_fields ) ) @@ -63,15 +66,13 @@ test_that('svm poly regression', { control = ctrl ) ) - }) test_that('svm rbf regression prediction', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(2, 1, 143) @@ -87,13 +88,16 @@ test_that('svm rbf regression prediction', { kern_pred <- structure( list(.pred = c(131.7743, 372.0932, 902.0633)), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_pred <- predict(reg_form, hpc[ind, -c(2, 5)]) - expect_equal(as.data.frame(kern_pred), - as.data.frame(parsnip_pred), - tolerance = .0001) - + expect_equal( + as.data.frame(kern_pred), + as.data.frame(parsnip_pred), + tolerance = .0001 + ) reg_xy_form <- fit_xy( @@ -102,21 +106,25 @@ test_that('svm rbf regression prediction', { y = hpc$input_fields, control = ctrl ) - expect_equal(extract_fit_engine(reg_form)@alphaindex, extract_fit_engine(reg_xy_form)@alphaindex) + expect_equal( + extract_fit_engine(reg_form)@alphaindex, + extract_fit_engine(reg_xy_form)@alphaindex + ) parsnip_xy_pred <- predict(reg_xy_form, hpc[ind, -c(2, 5)]) - expect_equal(as.data.frame(kern_pred), - as.data.frame(parsnip_xy_pred), - tolerance = .0001) + expect_equal( + as.data.frame(kern_pred), + as.data.frame(parsnip_xy_pred), + tolerance = .0001 + ) }) # ------------------------------------------------------------------------------ test_that('svm rbf classification', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(2, 1, 143) @@ -138,15 +146,13 @@ test_that('svm rbf classification', { control = ctrl ) ) - }) test_that('svm rbf classification probabilities', { - skip_if_not_installed("kernlab") - hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128),] |> + hpc_no_m <- hpc[-c(84, 85, 86, 87, 88, 109, 128), ] |> droplevels() ind <- c(4, 55, 143) @@ -161,11 +167,17 @@ test_that('svm rbf classification probabilities', { ) kern_class <- - structure(list( - .pred_class = structure( - c(1L, 1L, 3L), - .Label = c("VF", "F", "L"), class = "factor")), - row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")) + structure( + list( + .pred_class = structure( + c(1L, 1L, 3L), + .Label = c("VF", "F", "L"), + class = "factor" + ) + ), + row.names = c(NA, -3L), + class = c("tbl_df", "tbl", "data.frame") + ) parsnip_class <- predict(cls_form, hpc_no_m[ind, -5]) expect_equal(kern_class, parsnip_class) @@ -182,7 +194,11 @@ test_that('svm rbf classification probabilities', { library(kernlab) kern_probs <- - kernlab::predict(extract_fit_engine(cls_form), hpc_no_m[ind, -5], type = "probabilities") |> + kernlab::predict( + extract_fit_engine(cls_form), + hpc_no_m[ind, -5], + type = "probabilities" + ) |> as_tibble() |> setNames(c('.pred_VF', '.pred_F', '.pred_L')) diff --git a/tests/testthat/test-tunable.R b/tests/testthat/test-tunable.R index dfc011211..9d953ef09 100644 
--- a/tests/testthat/test-tunable.R +++ b/tests/testthat/test-tunable.R @@ -4,16 +4,18 @@ test_that('brulee has mixture object', { mlp_spec <- mlp( hidden_units = tune(), - activation = tune(), + activation = tune(), penalty = tune(), learn_rate = tune(), epoch = 2000 ) |> set_mode("regression") |> - set_engine("brulee", - stop_iter = tune(), - mixture = tune(), - rate_schedule = tune()) + set_engine( + "brulee", + stop_iter = tune(), + mixture = tune(), + rate_schedule = tune() + ) brulee_res <- tunable(mlp_spec) diff --git a/tests/testthat/test-update.R b/tests/testthat/test-update.R index b4bd1e780..121ffc1ab 100644 --- a/tests/testthat/test-update.R +++ b/tests/testthat/test-update.R @@ -1,11 +1,12 @@ test_that("update methods work (eg: linear_reg)", { expr1 <- linear_reg() |> set_engine("lm", model = FALSE) expr2 <- linear_reg() |> set_engine("glmnet", nlambda = tune()) - expr3 <- linear_reg(mixture = 0, penalty = tune()) |> set_engine("glmnet", nlambda = tune()) + expr3 <- linear_reg(mixture = 0, penalty = tune()) |> + set_engine("glmnet", nlambda = tune()) expr4 <- linear_reg(mixture = 0) |> set_engine("glmnet", nlambda = 10) expr5 <- linear_reg() |> set_engine("glm", family = "gaussian") - param_tibb <- tibble::tibble(mixture = 1/2, penalty = 1) + param_tibb <- tibble::tibble(mixture = 1 / 2, penalty = 1) param_list <- as.list(param_tibb) expect_snapshot(expr1 |> update(mixture = 0)) @@ -27,13 +28,16 @@ test_that("update methods prompt informatively", { # engine arguments passed in param expr1 <- linear_reg(mixture = 0) |> set_engine("glmnet", nlambda = 10) - param_tibb <- tibble::tibble(mixture = 1/2, nlambda = 5) + param_tibb <- tibble::tibble(mixture = 1 / 2, nlambda = 5) param_list <- as.list(param_tibb) expect_snapshot(error = TRUE, expr1 |> update(param_tibb)) expect_snapshot(error = TRUE, expr1 |> update(param_list)) expect_snapshot(error = TRUE, expr1 |> update(parameters = "wat")) - expect_snapshot(error = TRUE, expr1 |> update(parameters = tibble::tibble(wat = "wat"))) + expect_snapshot( + error = TRUE, + expr1 |> update(parameters = tibble::tibble(wat = "wat")) + ) # nonexistent main or eng args expect_snapshot(error = TRUE, linear_reg() |> update(boop = 0)) diff --git a/tests/testthat/test-varying.R b/tests/testthat/test-varying.R index d04cc4c4f..d69efd4c1 100644 --- a/tests/testthat/test-varying.R +++ b/tests/testthat/test-varying.R @@ -20,7 +20,7 @@ test_that('main parsnip arguments', { exp_2$varying[1] <- TRUE expect_equal(mod_2, exp_2) - mod_3 <- rand_forest(mtry = varying(), trees = varying()) |> + mod_3 <- rand_forest(mtry = varying(), trees = varying()) |> varying_args() exp_3 <- exp_2 @@ -45,7 +45,7 @@ test_that('other parsnip arguments', { expect_equal(other_1, exp_1) - other_2 <- rand_forest(min_n = varying()) |> + other_2 <- rand_forest(min_n = varying()) |> set_engine("ranger", sample.fraction = varying()) |> varying_args() @@ -60,20 +60,20 @@ test_that('other parsnip arguments', { # We can detect these as varying, but they won't actually # be used in this way - other_3 <- rand_forest() |> + other_3 <- rand_forest() |> set_engine("ranger", strata = Class, sampsize = c(varying(), varying())) |> varying_args() exp_3 <- tibble( - name = c("mtry", "trees", "min_n", "strata", "sampsize"), - varying = c(rep(FALSE, 4), TRUE), - id = rep("rand_forest", 5), - type = rep("model_spec", 5) - ) + name = c("mtry", "trees", "min_n", "strata", "sampsize"), + varying = c(rep(FALSE, 4), TRUE), + id = rep("rand_forest", 5), + type = rep("model_spec", 5) + ) 
expect_equal(other_3, exp_3) - other_4 <- rand_forest() |> + other_4 <- rand_forest() |> set_engine("ranger", strata = Class, sampsize = c(12, varying())) |> varying_args()
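
Illustrative sketch (not part of the diff above): the reformatted tests throughout these hunks share the same layout conventions -- piped model specifications, spaces around binary operators and after commas in subscripts, and one argument per line once a call grows long -- and they repeatedly exercise parsnip's predict() contract, where type = "class" yields a tibble with a `.pred_class` factor column and type = "prob" yields one `.pred_<level>` column per outcome level. The snippet below shows that pattern in miniature, assuming parsnip and tibble are installed; the object names (`two_class`, `lr_spec`, `lr_fit`) and the iris-based data are made up for the example and do not appear in the patch.

# Hypothetical example mirroring the layout and predict() checks used in the
# reformatted tests; not taken from the diff.
library(parsnip)
library(tibble)

# Two-class subset of iris as a stand-in classification data set.
two_class <- iris[iris$Species != "setosa", ]
two_class$Species <- droplevels(two_class$Species)

lr_spec <- logistic_reg() |>
  set_engine("glm")

# Multi-line call with one argument per line, matching the formatted style.
lr_fit <- fit(
  lr_spec,
  Species ~ Sepal.Length + Sepal.Width,
  data = two_class
)

cls <- predict(lr_fit, new_data = two_class[1:5, ], type = "class")
prb <- predict(lr_fit, new_data = two_class[1:5, ], type = "prob")

# Class predictions come back as a tibble with `.pred_class`; probabilities
# come back as `.pred_<level>` columns in the order of the factor levels.
stopifnot(
  is_tibble(cls),
  names(cls) == ".pred_class",
  is_tibble(prb),
  identical(names(prb), c(".pred_versicolor", ".pred_virginica"))
)

The behavior asserted here is unchanged by the patch; only the formatting of the test code differs.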