
glmnet multi_predict(): Check type for all modes #900

Merged
hfrick merged 5 commits into main from glmnet-multi_predict-type-check on Mar 14, 2023



hfrick commented on Mar 2, 2023

Because multi_predict() calls predict() internally with type = "raw", the type supplied to multi_predict() was not checked consistently: only classification models rejected an inappropriate type.

This PR extends the type check to all modes, not just "classification". Closes #517

```r
library(parsnip)
data(Chicago, package = "modeldata")

lm_spec <- linear_reg(penalty = 0.1) %>% set_engine("glmnet")
lm_fit <- fit(lm_spec, ridership ~ Clark_Lake + Quincy_Wells, data = Chicago)

multi_predict(lm_fit, Chicago[1:6,], penalty = c(0.05, 0.1), type = "class")
#> Error in `check_pred_type()` at parsnip/R/glmnet-engines.R:176:2:
#> ! For class predictions, the object should be a classification model.
```

Created on 2023-03-02 with reprex v2.0.2
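The underlying pattern can be illustrated with a small, self-contained sketch. It is not parsnip's implementation: `allowed_types`, `check_type_for_mode()`, `multi_predict_sketch()` and `raw_predict` are hypothetical names, and the per-mode type lists are an assumed subset chosen for illustration. The point it shows is that a wrapper which always asks the inner predictor for `"raw"` output has to validate the user's `type` itself, for every mode, before delegating.

```r
# Hypothetical sketch, not parsnip's code. The allowed types per mode are an
# assumed subset, chosen only to illustrate the check.
allowed_types <- list(
  regression     = c("numeric", "raw"),
  classification = c("class", "prob", "raw")
)

# Validate `type` against the model's mode; error if it does not fit.
check_type_for_mode <- function(type, mode) {
  if (!mode %in% names(allowed_types)) {
    stop("Unknown mode: ", mode, call. = FALSE)
  }
  if (!type %in% allowed_types[[mode]]) {
    stop(sprintf("Type '%s' is not available for %s models.", type, mode),
         call. = FALSE)
  }
  invisible(type)
}

# A multi_predict()-style wrapper: the inner call always uses type = "raw",
# so the user's `type` is only ever checked here, regardless of mode.
multi_predict_sketch <- function(mode, type, penalty, raw_predict) {
  check_type_for_mode(type, mode)
  lapply(penalty, function(p) raw_predict(p, type = "raw"))
}

# Usage: a dummy raw predictor that just echoes the penalty value.
res <- multi_predict_sketch("regression", "numeric", c(0.05, 0.1),
                            function(p, type) p)
```

Calling `multi_predict_sketch("regression", "class", ...)` stops with "Type 'class' is not available for regression models." before any raw predictions are computed, mirroring the behaviour shown in the reprex.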

hfrick requested a review from simonpcouch on March 6, 2023, 17:30

simonpcouch left a comment


Looks great! This is definitely an improvement.

hfrick merged commit d8f273c into main on Mar 14, 2023
9 checks passed
hfrick deleted the glmnet-multi_predict-type-check branch on March 14, 2023, 16:59
github-actions (bot) commented:

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

github-actions bot locked and limited conversation to collaborators on Mar 30, 2023
Linked issue that merging this pull request may close:

multi_predict._elnet() doesn't error on inappropriate type like type = "class"