Failure on predicting class prob from `multinom_reg()` with penalty #234

Closed
DavisVaughan opened this issue Nov 11, 2019 · 1 comment

Comments

Collaborator

@DavisVaughan DavisVaughan commented Nov 11, 2019

library(parsnip)

mod <- multinom_reg(penalty = 0.01)
mod <- set_engine(mod, "glmnet")

fit <- fit(mod, Species ~ Sepal.Length + Sepal.Width, iris)

predict(fit, iris, type = "prob")
#> Error: `predict()` doesn't work with multiple penalties (i.e. lambdas). Please specify a single value using `penalty = some_value` or use `multi_predict()` to get multiple predictions per row of data.

Created on 2019-11-11 by the reprex package (v0.3.0.9000)

The failure comes from this part of `parsnip:::predict_classprob.model_fit()`:

new_data <- object$spec$method$pred$prob$pre(new_data, object)
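For context on why the error talks about "multiple penalties": glmnet, called directly, fits a whole path of lambda values unless told otherwise, and a single value is picked at prediction time through the `s` argument. The sketch below calls glmnet on its own to show that behaviour. It is only an illustration of the underlying engine, not parsnip's internal code, and the object names (`x`, `y`, `path_fit`, `probs`) are made up for the example.

```r
library(glmnet)

# Same predictors and outcome as in the reprex above, in the matrix form glmnet expects
x <- as.matrix(iris[, c("Sepal.Length", "Sepal.Width")])
y <- iris$Species

# With no `lambda` supplied, glmnet fits an entire regularization path
path_fit <- glmnet(x, y, family = "multinomial")
length(path_fit$lambda)  # more than one lambda is stored in the fit

# Class probabilities for a single penalty are requested via `s`
probs <- predict(path_fit, newx = x, s = 0.01, type = "response")
dim(probs)  # observations x classes x number of `s` values

# Each row of probabilities should sum to 1 (up to floating point error)
range(rowSums(probs[, , 1]))
```

If this picture is right, the `pre` hook shown above is presumably where parsnip checks that exactly one penalty is available before predicting, which is the check that fails here even though `penalty = 0.01` was given.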

topepo added a commit that referenced this issue Dec 2, 2019
Collaborator

@topepo topepo commented Dec 2, 2019

Done!

library(parsnip)

mod <- multinom_reg(penalty = 0.01)
mod <- set_engine(mod, "glmnet")

fit <- fit(mod, Species ~ Sepal.Length + Sepal.Width, iris)

predict(fit, iris, type = "prob")
#> # A tibble: 150 x 3
#>    .pred_setosa .pred_versicolor .pred_virginica
#>           <dbl>            <dbl>           <dbl>
#>  1        0.978          0.0181         0.00390 
#>  2        0.872          0.114          0.0136  
#>  3        0.984          0.0146         0.00140 
#>  4        0.983          0.0160         0.00124 
#>  5        0.993          0.00619        0.00120 
#>  6        0.989          0.00767        0.00335 
#>  7        0.997          0.00277        0.000250
#>  8        0.977          0.0198         0.00344 
#>  9        0.980          0.0192         0.000962
#> 10        0.925          0.0669         0.00843 
#> # … with 140 more rows

Created on 2019-12-01 by the reprex package (v0.3.0)

@topepo topepo closed this Dec 2, 2019