diff --git a/R/glmnet.R b/R/glmnet.R new file mode 100644 index 000000000..598ab719e --- /dev/null +++ b/R/glmnet.R @@ -0,0 +1,154 @@ +# glmnet call stack using `predict()` when object has +# classes "_" and "model_fit": +# +# predict() +# predict._(penalty = NULL) +# predict_glmnet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_numeric() +# predict_numeric._() +# predict_numeric_glmnet() +# predict_numeric.model_fit() +# predict.() + + +# glmnet call stack using `multi_predict` when object has +# classes "_" and "model_fit": +# +# multi_predict() +# multi_predict._(penalty = NULL) +# predict._(multi = TRUE) +# predict_glmnet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._() +# predict_raw_glmnet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.() + + +predict_glmnet <- function(object, + new_data, + type = NULL, + opts = list(), + penalty = NULL, + multi = FALSE, + ...) { + + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (is.null(penalty) & !is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } + + object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) + + object$spec <- eval_args(object$spec) + predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) +} + +predict_numeric_glmnet <- function(object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + object$spec <- eval_args(object$spec) + predict_numeric.model_fit(object, new_data = new_data, ...) +} + +predict_class_glmnet <- function(object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + object$spec <- eval_args(object$spec) + predict_class.model_fit(object, new_data = new_data, ...) +} + +predict_classprob_glmnet <- function(object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + object$spec <- eval_args(object$spec) + predict_classprob.model_fit(object, new_data = new_data, ...) +} + +predict_raw_glmnet <- function(object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + object$spec <- eval_args(object$spec) + + opts$s <- object$spec$args$penalty + + predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) +} + +multi_predict_glmnet <- function(object, + new_data, + type = NULL, + penalty = NULL, + ...) { + + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + if (object$spec$mode == "classification") { + if (is_quosure(penalty)) { + penalty <- eval_tidy(penalty) + } + } + + dots <- list(...) + + object$spec <- eval_args(object$spec) + + if (is.null(penalty)) { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (!is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } else { + penalty <- object$fit$lambda + } + } + + if (object$spec$mode == "classification") { + if (is.null(type)) { + type <- "class" + } + if (!(type %in% c("class", "prob", "link", "raw"))) { + rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") + } + if (type == "prob") { + dots$type <- "response" + } else { + dots$type <- type + } + } + + pred <- predict(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) + + model_type <- class(object$spec)[1] + res <- switch( + model_type, + "linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty), + "logistic_reg" = format_glmnet_multi_logistic_reg(pred, + penalty = penalty, + type = dots$type, + lvl = object$lvl), + "multinom_reg" = format_glmnet_multi_multinom_reg(pred, + penalty = penalty, + type = type, + n_rows = nrow(new_data), + lvl = object$lvl) + ) + + res +} diff --git a/R/linear_reg.R b/R/linear_reg.R index 34243b905..00b17baac 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -156,94 +156,19 @@ check_args.linear_reg <- function(object) { res } -# ------------------------------------------------------------------------------ -# glmnet call stack for linear regression using `predict` when object has -# classes "_elnet" and "model_fit": -# -# predict() -# predict._elnet(penalty = NULL) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_numeric() -# predict_numeric._elnet() -# predict_numeric.model_fit() -# predict.elnet() - - -# glmnet call stack for linear regression using `multi_predict` when object has -# classes "_elnet" and "model_fit": -# -# multi_predict() -# multi_predict._elnet(penalty = NULL) -# predict._elnet(multi = TRUE) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_raw() -# predict_raw._elnet() -# predict_raw.model_fit(opts = list(s = penalty)) -# predict.elnet() - - #' @export -predict._elnet <- - function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (is.null(penalty) & !is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } - - object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) - - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) - } +predict._elnet <- predict_glmnet #' @export -predict_numeric._elnet <- function(object, new_data, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - object$spec <- eval_args(object$spec) - predict_numeric.model_fit(object, new_data = new_data, ...) -} +predict_numeric._elnet <- predict_numeric_glmnet #' @export -predict_raw._elnet <- function(object, new_data, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - object$spec <- eval_args(object$spec) - opts$s <- object$spec$args$penalty - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) -} +predict_raw._elnet <- predict_raw_glmnet #' @export #'@rdname multi_predict #' @param penalty A numeric vector of penalty values. -multi_predict._elnet <- - function(object, new_data, type = NULL, penalty = NULL, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - dots <- list(...) - - object$spec <- eval_args(object$spec) - - if (is.null(penalty)) { - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (!is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } else { - penalty <- object$fit$lambda - } - } - - pred <- predict._elnet(object, new_data = new_data, type = "raw", - opts = dots, penalty = penalty, multi = TRUE) - - format_glmnet_multi_linear_reg(pred, penalty = penalty) - } +multi_predict._elnet <- multi_predict_glmnet format_glmnet_multi_linear_reg <- function(pred, penalty) { param_key <- tibble(group = colnames(pred), penalty = penalty) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index c5f8c3f51..0e805f360 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -206,92 +206,14 @@ organize_glmnet_prob <- function(x, object) { res } -# ------------------------------------------------------------------------------ -# glmnet call stack for logistic regression using `predict` when object has -# classes "_lognet" and "model_fit" (for class predictions): -# -# predict() -# predict._lognet(penalty = NULL) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_class() -# predict_class._lognet() -# predict_class.model_fit() -# predict.lognet() - - -# glmnet call stack for logistic regression using `multi_predict` when object has -# classes "_lognet" and "model_fit" (for class predictions): -# -# multi_predict() -# multi_predict._lognet(penalty = NULL) -# predict._lognet(multi = TRUE) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_raw() -# predict_raw._lognet() -# predict_raw.model_fit(opts = list(s = penalty)) -# predict.lognet() - # ------------------------------------------------------------------------------ #' @export -predict._lognet <- function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (is.null(penalty) & !is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } - - object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) - - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) -} - +predict._lognet <- predict_glmnet #' @export #' @rdname multi_predict -multi_predict._lognet <- - function(object, new_data, type = NULL, penalty = NULL, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - if (is_quosure(penalty)) - penalty <- eval_tidy(penalty) - - dots <- list(...) - - if (is.null(penalty)) { - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (!is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } else { - penalty <- object$fit$lambda - } - } - - if (is.null(type)) - type <- "class" - if (!(type %in% c("class", "prob", "link", "raw"))) { - rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") - } - if (type == "prob") - dots$type <- "response" - else - dots$type <- type - - object$spec <- eval_args(object$spec) - pred <- predict._lognet(object, new_data = new_data, type = "raw", - opts = dots, penalty = penalty, multi = TRUE) - - format_glmnet_multi_logistic_reg( - pred, - penalty, - type = dots$type, - lvl = object$lvl - ) - } +multi_predict._lognet <- multi_predict_glmnet format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) { param_key <- tibble(group = colnames(pred), penalty = penalty) @@ -324,32 +246,13 @@ format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) { #' @export -predict_class._lognet <- function(object, new_data, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} +predict_class._lognet <- predict_class_glmnet #' @export -predict_classprob._lognet <- function(object, new_data, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} +predict_classprob._lognet <- predict_classprob_glmnet #' @export -predict_raw._lognet <- function(object, new_data, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - object$spec <- eval_args(object$spec) - opts$s <- object$spec$args$penalty - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) -} +predict_raw._lognet <- predict_raw_glmnet # ------------------------------------------------------------------------------ diff --git a/R/multinom_reg.R b/R/multinom_reg.R index dacd88b3e..465ebf205 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -142,115 +142,23 @@ organize_nnet_prob <- function(x, object) { -# ------------------------------------------------------------------------------ -# glmnet call stack for multinomial regression using `predict` when object has -# classes "_multnet" and "model_fit" (for class predictions): -# -# predict() -# predict._multnet(penalty = NULL) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_class() -# predict_class._multnet() -# predict.multnet() - - -# glmnet call stack for multinomial regression using `multi_predict` when object has -# classes "_multnet" and "model_fit" (for class predictions): -# -# multi_predict() -# multi_predict._multnet(penalty = NULL) -# predict._multnet(multi = TRUE) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_raw() -# predict_raw._multnet() -# predict_raw.model_fit(opts = list(s = penalty)) -# predict.multnet() - # ------------------------------------------------------------------------------ #' @export -predict._multnet <- - function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { - - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (is.null(penalty) & !is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } - - object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) - - object$spec <- eval_args(object$spec) - res <- predict.model_fit( - object = object, - new_data = new_data, - type = type, - opts = opts - ) - res - } +predict._multnet <- predict_glmnet #' @export #' @rdname multi_predict -multi_predict._multnet <- - function(object, new_data, type = NULL, penalty = NULL, ...) { - if (any(names(enquos(...)) == "newdata")) - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - - if (is_quosure(penalty)) - penalty <- eval_tidy(penalty) - - dots <- list(...) - - if (is.null(penalty)) { - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (!is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } else { - penalty <- object$fit$lambda - } - } - - if (is.null(type)) - type <- "class" - if (!(type %in% c("class", "prob", "link", "raw"))) { - rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") - } - if (type == "prob") - dots$type <- "response" - else - dots$type <- type - - object$spec <- eval_args(object$spec) - pred <- predict._multnet(object, new_data = new_data, type = "raw", - opts = dots, penalty = penalty, multi = TRUE) - - format_glmnet_multi_multinom_reg( - pred, - penalty = penalty, - type = type, - n_rows = nrow(new_data), - lvl = object$lvl - ) - } +multi_predict._multnet <- multi_predict_glmnet #' @export -predict_class._multnet <- function(object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} +predict_class._multnet <- predict_class_glmnet #' @export -predict_classprob._multnet <- function(object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} +predict_classprob._multnet <- predict_classprob_glmnet #' @export -predict_raw._multnet <- function(object, new_data, opts = list(), ...) { - object$spec <- eval_args(object$spec) - opts$s <- object$spec$args$penalty - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) -} +predict_raw._multnet <- predict_raw_glmnet format_glmnet_multi_multinom_reg <- function(pred, penalty, type, n_rows, lvl) { format_probs <- function(x) {