From 1b33a6eda723dbe0275ad3b88c83bcb5ce2ddfa4 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 25 Jan 2023 21:26:21 +0000 Subject: [PATCH 1/2] setting penalty in `predict_raw()` method so that it: - also gets applied in `predict(type = "raw")` - structure follows that of `linear_reg()`, which is also laid out in the comments --- DESCRIPTION | 2 +- R/logistic_reg.R | 7 ++++--- R/multinom_reg.R | 6 ++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1fd35efef..3c6feb9ba 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.0.3.9000 +Version: 1.0.3.9001 Authors@R: c( person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"), diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 86deb54fe..0b7728b5f 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -271,8 +271,6 @@ multi_predict._lognet <- } } - dots$s <- penalty - if (is.null(type)) type <- "class" if (!(type %in% c("class", "prob", "link", "raw"))) { @@ -284,7 +282,9 @@ multi_predict._lognet <- dots$type <- type object$spec <- eval_args(object$spec) - pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots) + pred <- predict._lognet(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) + param_key <- tibble(group = colnames(pred), penalty = penalty) pred <- as_tibble(pred) pred$.row <- 1:nrow(pred) @@ -340,6 +340,7 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) { rlang::abort("Did you mean to use `new_data` instead of `newdata`?") object$spec <- eval_args(object$spec) + opts$s <- object$spec$args$penalty predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 251f52ec4..1293255d0 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -200,6 +200,7 @@ multi_predict._multnet <- penalty <- eval_tidy(penalty) dots <- list(...) + if (is.null(penalty)) { # See discussion in https://github.com/tidymodels/parsnip/issues/195 if (!is.null(object$spec$args$penalty)) { @@ -208,7 +209,6 @@ multi_predict._multnet <- penalty <- object$fit$lambda } } - dots$s <- penalty if (is.null(type)) type <- "class" @@ -221,7 +221,8 @@ multi_predict._multnet <- dots$type <- type object$spec <- eval_args(object$spec) - pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots) + pred <- predict._lognet(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) format_probs <- function(x) { x <- as_tibble(x) @@ -268,5 +269,6 @@ predict_classprob._multnet <- function(object, new_data, ...) { #' @export predict_raw._multnet <- function(object, new_data, opts = list(), ...) { object$spec <- eval_args(object$spec) + opts$s <- object$spec$args$penalty predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } From b652e0d17e102d34397e048872822dc715c68519 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 25 Jan 2023 21:37:18 +0000 Subject: [PATCH 2/2] sticking to the general pattern --- R/multinom_reg.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 1293255d0..64bb10774 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -221,8 +221,8 @@ multi_predict._multnet <- dots$type <- type object$spec <- eval_args(object$spec) - pred <- predict._lognet(object, new_data = new_data, type = "raw", - opts = dots, penalty = penalty, multi = TRUE) + pred <- predict._multnet(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) format_probs <- function(x) { x <- as_tibble(x)