From 76ba42cfd543b10ba585654da640a5cf1ce59add Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 22 Feb 2023 15:28:29 +0000 Subject: [PATCH 1/2] simplify glmnet formatting helpers --- R/linear_reg.R | 13 +------------ R/logistic_reg.R | 28 +++------------------------- 2 files changed, 4 insertions(+), 37 deletions(-) diff --git a/R/linear_reg.R b/R/linear_reg.R index 00b17baac..0ec90179d 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -142,18 +142,7 @@ check_args.linear_reg <- function(object) { #' @keywords internal #' @export .organize_glmnet_pred <- function(x, object) { - if (ncol(x) == 1) { - res <- x[, 1] - res <- unname(res) - } else { - n <- nrow(x) - res <- utils::stack(as.data.frame(x)) - if (!is.null(object$spec$args$penalty)) - res$lambda <- rep(object$spec$args$penalty, each = n) else - res$lambda <- rep(object$fit$lambda, each = n) - res <- res[, colnames(res) %in% c("values", "lambda")] - } - res + unname(x[, 1]) } #' @export diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 0e805f360..26d93dc1d 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -174,35 +174,13 @@ prob_to_class_2 <- function(x, object) { unname(x) } - organize_glmnet_class <- function(x, object) { - if (ncol(x) == 1) { - res <- prob_to_class_2(x[, 1], object) - } else { - n <- nrow(x) - res <- utils::stack(as.data.frame(x)) - res$values <- prob_to_class_2(res$values, object) - if (!is.null(object$spec$args$penalty)) - res$lambda <- rep(object$spec$args$penalty, each = n) else - res$lambda <- rep(object$fit$lambda, each = n) - res <- res[, colnames(res) %in% c("values", "lambda")] - } - res + prob_to_class_2(x[, 1], object) } organize_glmnet_prob <- function(x, object) { - if (ncol(x) == 1) { - res <- tibble(v1 = 1 - x[, 1], v2 = x[, 1]) - colnames(res) <- object$lvl - } else { - n <- nrow(x) - res <- utils::stack(as.data.frame(x)) - res <- tibble(v1 = 1 - res$values, v2 = res$values) - colnames(res) <- object$lvl - if (!is.null(object$spec$args$penalty)) - res$lambda <- rep(object$spec$args$penalty, each = n) else - res$lambda <- rep(object$fit$lambda, each = n) - } + res <- tibble(v1 = 1 - x[, 1], v2 = x[, 1]) + colnames(res) <- object$lvl res } From 02de86043259d5c2c1c52953f3bceb9f135a3d6a Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 23 Feb 2023 13:42:36 +0000 Subject: [PATCH 2/2] add NEWS entry for the exported function --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 32fc011fe..bcf1cd289 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # parsnip (development version) +* `.organize_glmnet_pred()` now expects predictions for a single penalty value (#876). + # parsnip 1.0.4 * For censored regression models, a "reverse Kaplan-Meier" curve is computed for the censoring distribution. This can be used when evaluating this type of model (#855).