
Support glmnet models with base-R families #890

Merged: 6 commits merged into main on Mar 2, 2023

Conversation

@hfrick (Member) commented on Feb 23, 2023

Closes #483 and closes #738

Since version 4.0, glmnet has supported base-R family objects for its family argument. The resulting fitted model objects all gain a new class, glmnetfit, which is not yet supported in parsnip.
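For illustration, a minimal sketch (assuming glmnet >= 4.0 is installed) of how the class of the fitted object depends on how the family is specified:

```r
library(glmnet)

set.seed(1)
x <- matrix(rnorm(100 * 5), ncol = 5)
y <- rnorm(100)

# Built-in family, specified as a string: the classic Fortran path is used
fit_builtin <- glmnet(x, y, family = "gaussian")
class(fit_builtin)  # "elnet" "glmnet"

# Base-R family object: fitted via IRLS, yielding a "glmnetfit" object
fit_base <- glmnet(x, y, family = gaussian())
class(fit_base)     # "glmnetfit" "glmnet"
```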

We need several predict methods for this new class that check and set the penalty argument to work with the rest of parsnip.

We also need to translate between parsnip's prediction types ("numeric", "class", "prob", etc.) and glmnet's prediction types ("response", "class", etc.). The gotcha here is that type = "class" exists for logistic and multinomial regression models fitted with glmnet's built-in families (resulting in objects of classes lognet and multnet) but does not exist for models fitted with any of the base-R families (resulting in objects of class glmnetfit).

For glmnetfit models, we do have type = "response" available, which, for logistic and multinomial regression, gives us the class probabilities which we can translate to hard class predictions.
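For a binary outcome, that translation from type = "response" probabilities to hard class predictions is a few lines of base R. A minimal sketch, where prob and lvl are hypothetical stand-ins for the predicted event probabilities and the outcome's factor levels (glmnet treats the second level as the event):

```r
# Probabilities of the second (event) level, as returned by
# predict(object, newx, type = "response") for a binomial glmnetfit model
prob <- c(0.12, 0.55, 0.97, 0.49)

# Factor levels of the outcome; the second level is the event
lvl <- c("no", "yes")

# Threshold at 0.5 and restore the original factor levels
pred_class <- factor(ifelse(prob >= 0.5, lvl[2], lvl[1]), levels = lvl)
pred_class
```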

We already do that for predict() on logistic regression, see here and here, but not for multi_predict(). For multinomial regression, we use type = "class" for both predict() and multi_predict().

I have contemplated whether we should use type = "class" wherever we can and then special-case the glmnetfit models, or whether we should lean the other way, use type = "response" consistently, and translate from probabilities to class predictions even in cases where glmnet could give us classes directly. I have landed on a pragmatic middle with this PR.

For multinomial regression, I don't see a corresponding suitable base-R family, so there is no added complexity via glmnetfit objects. I have therefore left this as is for the time being.

For logistic regression, this PR leans into "use type 'response' everywhere" because special-casing glmnetfit objects would be a particular challenge. For predict(), the translation between parsnip's prediction type and glmnet's prediction type happens when predict_<type>.model_fit() constructs the call to glmnet's prediction method, using the type defined in set_pred(); there isn't really a (good?) way to make that conditional on the class of the model fit object. The (only? canonical?) way to manually take control of the translation between prediction types is what multi_predict() does: it uses the parsnip type = "raw" and sets glmnet's type as part of the opts which are passed down to glmnet's predict method. Doing this for predict() in addition to multi_predict() would defeat the whole purpose of parsnip's predict machinery. Hence: no special-casing of glmnetfit objects, but rather type = "response" everywhere, with control over how the predictions are formatted at the end (i.e., turning probabilities into class predictions or not).
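As a concrete illustration of that type = "raw" escape hatch (a sketch, assuming parsnip, glmnet, and modeldata are installed; two_class_dat is modeldata's example binary-outcome data set, used here for illustration):

```r
library(parsnip)
library(modeldata)  # provides the two_class_dat example data

# Fit a logistic regression with the glmnet engine via parsnip
mod_fit <- logistic_reg(penalty = 0.01) %>%
  set_engine("glmnet") %>%
  fit(Class ~ ., data = two_class_dat)

# type = "raw" hands control back to the caller: everything in `opts` is
# forwarded to glmnet's predict method, so we can ask glmnet directly for
# probabilities (its type = "response") at a given penalty value (s)
raw_pred <- predict(
  mod_fit,
  new_data = two_class_dat,
  type = "raw",
  opts = list(s = 0.01, type = "response")
)
```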

The corresponding tests are with the other glmnet-related tests in extratests: tidymodels/extratests#77

    # these objects result from using base-R families
    # because `type = "class"` is not available in glmnet for `glmnetfit` objects
    if (type == "class") {
      pred <- pred %>%
        dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]))
      pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = lvl)
    }
@hfrick (Member, Author) commented:
Apart from this line, which adds the translation from probabilities to classes, the changes to this function only switch it over to dplyr and tidyr.

@simonpcouch (Contributor) left a comment:

Do-er of diligence! Effectuator of exactness! Agent of accuracy!

This PR is super solid—no comments. Thanks for the helpful description.🍄

    res <- switch(
      model_type,
      "linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty),
      "logistic_reg" = format_glmnet_multi_logistic_reg(pred,
                                                        penalty = penalty,
                                                        type = type,
@simonpcouch (Contributor) commented:
A note to help speed up @topepo's review: format_glmnet_multi_logistic_reg() used to take whichever type was passed along to glmnet, and its internals read:

if (type == "class") {
  # ...
} else {
  # ...
}

That helper is now supplied the "parsnip" type instead, in this case one of "class" or "prob", so that the above conditional has only one possible value for the else. The format_glmnet_multi_multinom_reg() helper does this already.
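In shape, the simplified conditional now looks roughly like this (a hypothetical sketch, not the actual helper body; format_binary_preds is an invented name, pred holds event probabilities, and lvl the outcome levels):

```r
# Hypothetical sketch of the conditional inside a helper like
# format_glmnet_multi_logistic_reg(): `type` is now the parsnip type,
# so only "class" and "prob" can reach this point
format_binary_preds <- function(pred, lvl, type) {
  if (type == "class") {
    # turn event probabilities into hard class predictions
    factor(ifelse(pred >= 0.5, lvl[2], lvl[1]), levels = lvl)
  } else {
    # `type` can only be "prob" here: return the probabilities as-is
    pred
  }
}
```

The real helper additionally carries penalty values and tibble formatting; this sketch only shows why the else branch now has a single possible meaning.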

@hfrick (Member, Author) commented:

@simonpcouch Thanks for adding this! 🙌

Merge commit '070f1b2ccf92ae95aa16dc34153c9aac874d6ccf'

# Conflicts:
#	NEWS.md
#	R/glmnet.R
@hfrick hfrick merged commit 92b2bd5 into main Mar 2, 2023
@hfrick hfrick deleted the glmnet-base-family branch March 2, 2023 10:44
@github-actions (bot) commented:
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 17, 2023
Successfully merging this pull request may close these issues:

- Unable to tune penalty for glmnet with non-default family
- Handle quasi-likelihood glmnet models better

3 participants