Skip to content

Commit a517c87

Browse files
authored
Setting the penalty for logistic and multinomial regression with glmnet (#863)
* setting penalty in `predict_raw()` method so that it: - also gets applied in `predict(type = "raw")` - structure follows that of `linear_reg()`, which is also laid out in the comments * sticking to the general pattern
1 parent 2249cbb commit a517c87

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.3.9000
3+
Version: 1.0.3.9001
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"),

R/logistic_reg.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,6 @@ multi_predict._lognet <-
271271
}
272272
}
273273

274-
dots$s <- penalty
275-
276274
if (is.null(type))
277275
type <- "class"
278276
if (!(type %in% c("class", "prob", "link", "raw"))) {
@@ -284,7 +282,9 @@ multi_predict._lognet <-
284282
dots$type <- type
285283

286284
object$spec <- eval_args(object$spec)
287-
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
285+
pred <- predict._lognet(object, new_data = new_data, type = "raw",
286+
opts = dots, penalty = penalty, multi = TRUE)
287+
288288
param_key <- tibble(group = colnames(pred), penalty = penalty)
289289
pred <- as_tibble(pred)
290290
pred$.row <- 1:nrow(pred)
@@ -340,6 +340,7 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
340340
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
341341

342342
object$spec <- eval_args(object$spec)
343+
opts$s <- object$spec$args$penalty
343344
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
344345
}
345346

R/multinom_reg.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ multi_predict._multnet <-
200200
penalty <- eval_tidy(penalty)
201201

202202
dots <- list(...)
203+
203204
if (is.null(penalty)) {
204205
# See discussion in https://github.com/tidymodels/parsnip/issues/195
205206
if (!is.null(object$spec$args$penalty)) {
@@ -208,7 +209,6 @@ multi_predict._multnet <-
208209
penalty <- object$fit$lambda
209210
}
210211
}
211-
dots$s <- penalty
212212

213213
if (is.null(type))
214214
type <- "class"
@@ -221,7 +221,8 @@ multi_predict._multnet <-
221221
dots$type <- type
222222

223223
object$spec <- eval_args(object$spec)
224-
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
224+
pred <- predict._multnet(object, new_data = new_data, type = "raw",
225+
opts = dots, penalty = penalty, multi = TRUE)
225226

226227
format_probs <- function(x) {
227228
x <- as_tibble(x)
@@ -268,5 +269,6 @@ predict_classprob._multnet <- function(object, new_data, ...) {
268269
#' @export
269270
predict_raw._multnet <- function(object, new_data, opts = list(), ...) {
270271
object$spec <- eval_args(object$spec)
272+
opts$s <- object$spec$args$penalty
271273
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
272274
}

0 commit comments

Comments
 (0)