Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

augment() for models without class probabilities #487

Merged
merged 5 commits into from
May 13, 2021
Merged

Conversation

juliasilge
Copy link
Member

@juliasilge juliasilge commented May 11, 2021

Closes #435

This PR adds a tryCatch() for making the type = "prob" predictions in augment(), so that it only returns the class predictions if the class probability predictions fail:

library(parsnip)

svm_spec <- svm_linear() %>%
  set_mode("classification") %>%
  set_engine("LiblineaR")

svm_fit <- fit(svm_spec, Species ~ ., data = iris)

augment(svm_fit, iris)
#> # A tibble: 150 x 6
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width Species .pred_class
#>           <dbl>       <dbl>        <dbl>       <dbl> <fct>   <fct>      
#>  1          5.1         3.5          1.4         0.2 setosa  setosa     
#>  2          4.9         3            1.4         0.2 setosa  setosa     
#>  3          4.7         3.2          1.3         0.2 setosa  setosa     
#>  4          4.6         3.1          1.5         0.2 setosa  setosa     
#>  5          5           3.6          1.4         0.2 setosa  setosa     
#>  6          5.4         3.9          1.7         0.4 setosa  setosa     
#>  7          4.6         3.4          1.4         0.3 setosa  setosa     
#>  8          5           3.4          1.5         0.2 setosa  setosa     
#>  9          4.4         2.9          1.4         0.2 setosa  setosa     
#> 10          4.9         3.1          1.5         0.1 setosa  setosa     
#> # … with 140 more rows

data(two_class_dat, package = "modeldata")

cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = two_class_dat)

augment(cls_form, two_class_dat)
#> # A tibble: 791 x 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.513      0.487  
#>  2  2.02 1.04  Class1 Class1             0.906      0.0940 
#>  3  1.69 1.37  Class2 Class1             0.645      0.355  
#>  4  3.43 1.98  Class2 Class1             0.599      0.401  
#>  5  2.88 1.98  Class1 Class2             0.435      0.565  
#>  6  3.31 2.41  Class2 Class2             0.201      0.799  
#>  7  2.50 1.56  Class2 Class1             0.701      0.299  
#>  8  1.98 1.55  Class2 Class1             0.563      0.437  
#>  9  2.88 0.580 Class1 Class1             0.994      0.00622
#> 10  3.74 2.74  Class2 Class2             0.105      0.895  
#> # … with 781 more rows

Created on 2021-05-11 by the reprex package (v2.0.0)

@juliasilge juliasilge changed the title Try catch augment augment() for models without class probabilities May 11, 2021
@juliasilge juliasilge requested a review from topepo May 11, 2021 22:04
@juliasilge
Copy link
Member Author

This new change ⬆️ will fail until #489 is merged in.

@topepo topepo merged commit 6d6b0a7 into master May 13, 2021
@topepo topepo deleted the try-catch-augment branch May 13, 2021 16:48
@github-actions
Copy link

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators May 28, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

augment.model_fit throws error if classification model doesn't support class probabilities
2 participants