-
Notifications
You must be signed in to change notification settings - Fork 90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support glmnet models with base-R families #890
Conversation
there objects result from using base-R families
because `type = "class"` is not available in glmnet for `glmnetfit` objects
if (type == "class") { | ||
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = lvl) | ||
pred <- pred %>% | ||
dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apart from this line adding the translation from probabilities to classes, the changes to this function are only to make use of dplyr and tidyr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do-er of diligence! Effectuator of exactness! Agent of accuracy!
This PR is super solid—no comments. Thanks for the helpful description.🍄
res <- switch( | ||
model_type, | ||
"linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty), | ||
"logistic_reg" = format_glmnet_multi_logistic_reg(pred, | ||
penalty = penalty, | ||
type = dots$type, | ||
type = type, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A note to help speed up @topepo's review: format_glmnet_multi_logistic_reg()
used to take whichever type
was passed along to glmnet, and its internals read:
if (type == "class") {
# ...
} else {
# ...
}
That helper is now supplied the "parsnip" type instead, in this case one of "class"
or "prob"
, so that the above conditional has only one possible value for the else
. The format_glmnet_multi_multinom_reg()
helper does this already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@simonpcouch Thanks for adding this! 🙌
Merge commit '070f1b2ccf92ae95aa16dc34153c9aac874d6ccf' #Conflicts: # NEWS.md # R/glmnet.R
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue. |
Closes #483 and closes #738
Since version 4.0, glmnet supports base-R families for the
family
argument. The resulting fitted model objects all get a new classglmnetfit
which is not supported in parsnip yet.We need several predict methods for this new class that check and set the penalty argument to work with the rest of parsnip.
We also need to translate between parsnip's prediction type (
"numeric"
,"class"
,"prob"
, etc) and glmnet's prediction type ("response"
,"class"
, etc). The gotcha moment here is thattype = "class"
exists for logistic and multinomial regression models fitted with glmnet's built-in families, resulting in objects of classeslognet
andmultnet
, but does not exist for models fitted with any of the base-R families, resulting in objects of classglmnetfit
.For
glmnetfit
models, we do havetype = "response"
available, which, for logistic and multinomial regression, gives us the class probabilities which we can translate to hard class predictions.We already do that for
predict()
on logistic regression, see here and here, but not formulti_predict()
. For multinomial regression, we usetype = "class"
for bothpredict()
andmulti_predict()
.I have contemplated if we should try to use
type = "class"
wherever we can and then special-case theglmnetfit
models or if we should lean the other way and usetype = "response"
consistently and translate from probabilities to class predictions even in cases where we could get that from glmnet. I have landed in a pragmatic middle with this PR.For multinomial regression, I don't see a corresponding suitable base-R family, so there is no added complexity via
glmnetfit
objects. I have therefore left this as is for the time being.For logistic regression, this PR leans into "use type 'response' everywhere" because special-casing
glmnetfit
object would be a particular challenge: Forpredict()
, the translation between parsnip's prediction type and glmnet's prediction type happens whenpredict_<type>.model_fit()
constructs the call to glmnet's prediction method, using the type defined inset_pred()
. There isn't really a (good?) way to set that conditional on the class of the model fit object. The (only? canonical?) way to manually take control of the translation between the prediction types is whatmulti_predict()
does: it uses the parsniptype = "raw"
and sets glmnet'stype
as part of theopts
which are passed down to glmnet's predict method. Trying to do this forpredict()
in addition tomulti_predict()
would defy the whole purpose of parsnip's predict machinery. Hence, no special-casing theglmnetfit
objects but rathertype = "response"
everywhere and taking control of how the predictions are formatted at the end (aka turn probabilities into class predictions or not).The corresponding tests are with the other glmnet-related tests in extratests: tidymodels/extratests#77