The functions `.organize_glmnet_pred()`, `organize_glmnet_class()`, and `organize_glmnet_prob()` are used in the post hook of the prediction module for the glmnet engines for `linear_reg()` and `logistic_reg()`.
For example, in line 246 of parsnip/R/linear_reg_data.R (lines 239 to 256 at commit a482442):

```r
set_pred(
  model = "linear_reg",
  eng = "glmnet",
  mode = "regression",
  type = "numeric",
  value = list(
    pre = NULL,
    post = .organize_glmnet_pred,
    func = c(fun = "predict"),
    args =
      list(
        object = expr(object$fit),
        newx = expr(as.matrix(new_data[, rownames(object$fit$beta), drop = FALSE])),
        type = "response",
        s = expr(object$spec$args$penalty)
      )
  )
)
```
Their job is to reshape the predictions into a form that `format_num()`, `format_class()`, or `format_classprobs()` can work with; those formatters are called inside of `predict.model_fit()`.
Since they are only ever called this way (here in parsnip, plus the exported `.organize_glmnet_pred()` in censored and poissonreg), they always deal with predictions for a single penalty value. They do contain code to handle predictions for multiple penalty values, but we don't need it, so I suggest we remove it.
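As a rough sketch of what a simplified post hook for the numeric case could look like once the multi-penalty branch is gone: the function name `organize_single_penalty()` and the stand-in matrix below are hypothetical, and the sketch assumes that the raw glmnet prediction for a single penalty value arrives as a one-column matrix and that the hook is called with the prediction and the model object.

```r
# Hypothetical simplified post hook, not parsnip's actual implementation.
# Assumption: `x` is a one-column matrix (one prediction per row of new data).
organize_single_penalty <- function(x, object = NULL) {
  stopifnot(is.matrix(x), ncol(x) == 1)
  # Drop the matrix structure and row names so a plain numeric vector remains
  unname(x[, 1])
}

# Stand-in for the raw prediction output; the values are made up for illustration
raw_pred <- matrix(
  c(570.2, 163.4, 167.1),
  ncol = 1,
  dimnames = list(c("1", "2", "3"), "s1")
)

organize_single_penalty(raw_pred)
#> [1] 570.2 163.4 167.1
```

With only one penalty value to handle, the hook has nothing to stack or label, which is the simplification being proposed.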
Why only ever a single penalty value? Because we check for that with `.check_glmnet_penalty_predict()`. It is possible to get around this check by setting `multi = TRUE`, which is used in combination with `type = "raw"` inside of `multi_predict()`. Type `"raw"` ensures that we don't call any of the post hook functions, so the predictions don't go through the `organize_glmnet_*()` functions.
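To make the logic of that check concrete, here is a minimal stand-in. It is not the code of `.check_glmnet_penalty_predict()`; the name `check_single_penalty()` and its signature are hypothetical, and it only mirrors the behaviour described above: reject anything but a single penalty value unless `multi = TRUE`.

```r
# Hypothetical stand-in for the single-penalty check, not parsnip's implementation.
check_single_penalty <- function(penalty, multi = FALSE) {
  # Without `multi = TRUE`, exactly one penalty value is allowed
  if (!multi && length(penalty) != 1) {
    stop("`penalty` should be a single numeric value.", call. = FALSE)
  }
  penalty
}

check_single_penalty(0.1)
#> [1] 0.1

# `multi = TRUE` skips the length check
check_single_penalty(1:2, multi = TRUE)
#> [1] 1 2

# Without `multi = TRUE`, more than one value is an error
res <- try(check_single_penalty(1:2), silent = TRUE)
inherits(res, "try-error")
#> [1] TRUE
```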
Getting around the check this way outside of `multi_predict()` is off-label usage, though, and I think we would be okay not supporting it in exchange for simpler post hook functions.
```r
library(parsnip)
data("hpc_data", package = "modeldata")
hpc <- hpc_data[1:150, c(2:5, 8)]

f_fit <- linear_reg(penalty = 0.1, mixture = 0.3) %>%
  set_engine("glmnet", nlambda = 15) %>%
  fit(input_fields ~ log(compounds) + class, data = hpc)

# regular usage errors informatively
predict(f_fit, hpc[1:3, ], penalty = 1:2)
#> Error in `.check_glmnet_penalty_predict()`:
#> ! `penalty` should be a single numeric value. `multi_predict()` can be used to get multiple predictions per row of data.

# off-label usage
predict(f_fit, hpc[1:3, ], penalty = 1:2, multi = TRUE)
#> # A tibble: 6 × 2
#>   .pred_values .pred_lambda
#>          <dbl>        <int>
#> 1         570.            1
#> 2         163.            1
#> 3         167.            1
#> 4         570.            2
#> 5         163.            2
#> 6         168.            2
```
Created on 2023-02-22 with reprex v2.0.2