Inconsistency with return type for probability type predictions for two class GAM models #708

Closed
oj713 opened this issue Apr 28, 2022 · 5 comments

@oj713
Contributor

oj713 commented Apr 28, 2022

The problem

I'm encountering an inconsistency when generating probability predictions for two-class GAM models. Each prediction column in the returned tibble is a one-dimensional array rather than a plain numeric vector. I believe this happens because mgcv returns its predictions as an array rather than a vector.

Issue #541 covers a similar problem with GAM class predictions, where mgcv's prediction format was also the cause. I believe the fix here would closely mirror the fix for that issue (#542).
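
To illustrate the array-versus-vector distinction described above, here is a minimal base-R sketch. It does not call mgcv or parsnip: the one-dimensional array `p` is a hypothetical stand-in for the kind of object this issue attributes to mgcv, and `as.vector()` is shown only as one plausible way to flatten it, not necessarily what the actual fix does.

```r
# Hypothetical stand-in for a one-dimensional array of predicted
# probabilities (assumption: this mimics the shape described in the issue).
p <- array(c(0.2, 0.7, 0.9), dim = 3)

class(p)       # "array", because the object carries a dim attribute
is.vector(p)   # FALSE, for the same reason

# as.vector() drops the dim attribute, leaving a plain numeric vector.
p_flat <- as.vector(p)

class(p_flat)  # "numeric"
dim(p_flat)    # NULL
```

The values are unchanged by the conversion; only the `dim` attribute, and therefore the reported class, differs.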

Reproducible example

library(parsnip)

set.seed(33)
twoClassSim <- caret::twoClassSim()

gam <- gen_additive_mod() |>
  set_mode("classification") |>
  fit(Class ~ Linear01, 
      data = twoClassSim)

pred <- predict(gam, head(twoClassSim), type = "prob")

# demonstrating the problem: the prediction columns are arrays, not plain numeric vectors
class(pred$.pred_Class1)
#> [1] "array"
dplyr::glimpse(pred)
#> Rows: 6
#> Columns: 2
#> $ .pred_Class1 <dbl> <array[6]>
#> $ .pred_Class2 <dbl> <array[6]>

# session information 
sessionInfo()
#> R version 4.2.0 (2022-04-22)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Monterey 12.3.1
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] parsnip_0.2.1
#> 
#> loaded via a namespace (and not attached):
#>  [1] Rcpp_1.0.8.3         lubridate_1.8.0      lattice_0.20-45     
#>  [4] listenv_0.8.0        tidyr_1.2.0          class_7.3-20        
#>  [7] digest_0.6.29        ipred_0.9-12         foreach_1.5.2       
#> [10] utf8_1.2.2           parallelly_1.31.1    R6_2.5.1            
#> [13] plyr_1.8.7           stats4_4.2.0         reprex_2.0.1        
#> [16] hardhat_0.2.0        evaluate_0.15        ggplot2_3.3.5       
#> [19] highr_0.9            pillar_1.7.0         rlang_1.0.2         
#> [22] caret_6.0-92         data.table_1.14.2    rstudioapi_0.13     
#> [25] rpart_4.1.16         Matrix_1.4-1         rmarkdown_2.14      
#> [28] splines_4.2.0        gower_1.0.0          stringr_1.4.0       
#> [31] munsell_0.5.0        compiler_4.2.0       xfun_0.30           
#> [34] pkgconfig_2.0.3      mgcv_1.8-40          globals_0.14.0      
#> [37] htmltools_0.5.2      nnet_7.3-17          tidyselect_1.1.2    
#> [40] tibble_3.1.6         prodlim_2019.11.13   codetools_0.2-18    
#> [43] fansi_1.0.3          future_1.25.0        crayon_1.5.1        
#> [46] dplyr_1.0.9          withr_2.5.0          ModelMetrics_1.2.2.2
#> [49] MASS_7.3-56          recipes_0.2.0        grid_4.2.0          
#> [52] nlme_3.1-157         gtable_0.3.0         lifecycle_1.0.1     
#> [55] magrittr_2.0.3       pROC_1.18.0          scales_1.2.0        
#> [58] future.apply_1.9.0   cli_3.3.0            stringi_1.7.6       
#> [61] reshape2_1.4.4       fs_1.5.2             timeDate_3043.102   
#> [64] ellipsis_0.3.2       generics_0.1.2       vctrs_0.4.1         
#> [67] lava_1.6.10          iterators_1.0.14     tools_4.2.0         
#> [70] glue_1.6.2           purrr_0.3.4          parallel_4.2.0      
#> [73] fastmap_1.1.0        survival_3.3-1       yaml_2.3.5          
#> [76] colorspace_2.0-3     knitr_1.39

Created on 2022-04-28 by the reprex package (v2.0.1)

@oj713
Contributor Author

oj713 commented Apr 28, 2022

I forked the parsnip repository and created a possible fix, available here.

@EmilHvitfeldt
Member

Hello @oj713, that looks great, good find! If you'd like to, we would be happy for you to create a PR; if not, then I'll take it from here.

@oj713
Contributor Author

oj713 commented Apr 28, 2022

Hi @EmilHvitfeldt, thank you! I'll go ahead and create a PR now.

@EmilHvitfeldt
Member

Closed via #709

@github-actions

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

github-actions bot locked and limited conversation to collaborators on May 14, 2022