
Inconsistent return type for probability predictions from two-class GAM models #708

@oj713


The problem

When I generate probability predictions (`type = "prob"`) for two-class GAM models, the prediction columns come back as one-dimensional arrays inside the tibble instead of plain numeric vectors. I believe this happens because mgcv returns its predictions as an array rather than a vector.
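The suspected cause can be checked with mgcv alone. The sketch below fits a plain binomial `gam()` on simulated data (the data, `dat`, and `fit` are illustrative assumptions, not taken from the reprex) and inspects the dimensions of the `type = "response"` predictions. A non-`NULL`, length-one `dim()` means the result is a 1-d array rather than a bare numeric vector.

```r
library(mgcv)

# Simulated two-class data (illustrative only)
set.seed(1)
dat <- data.frame(x = rnorm(100))
dat$y <- rbinom(100, 1, plogis(dat$x))

# Binomial GAM, the same family parsnip uses for two-class problems
fit <- gam(y ~ x, data = dat, family = binomial)

# Predicted probabilities for the second class
p <- predict(fit, newdata = head(dat), type = "response")

# Inspect the structure: a 1-d array carries a dim attribute
dim(p)
str(p)
```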

Issue #541 describes a similar problem with GAM class predictions, where mgcv's prediction format was also the cause. I expect the fix here would closely mirror the fix for that issue (fix #542).
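A minimal sketch of what such a fix could look like, assuming the probability post-processing receives mgcv's 1-d array of second-class probabilities plus the outcome levels. `format_gam_probs()` is a hypothetical helper, not parsnip's actual code. The key step is `as.vector()`, which drops the array's `dim` and `dimnames` attributes so that `1 - x` and `x` become plain numeric columns.

```r
library(tibble)

# Hypothetical helper (not parsnip's implementation): turn mgcv's
# 1-d array of P(second level) into a two-column probability tibble.
format_gam_probs <- function(x, lvl) {
  x <- as.vector(x)  # strip dim/dimnames so columns are plain doubles
  res <- tibble(v1 = 1 - x, v2 = x)
  colnames(res) <- paste0(".pred_", lvl)
  res
}

# Stand-in for mgcv output: a named 1-d array of probabilities
raw <- array(c(0.2, 0.7, 0.9), dim = 3, dimnames = list(c("1", "2", "3")))
probs <- format_gam_probs(raw, c("Class1", "Class2"))
class(probs$.pred_Class1)
```

Coercing once, before the tibble is built, keeps both columns consistent because `1 - x` would otherwise inherit the array attributes from `x`.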

Reproducible example

library(parsnip)

set.seed(33)
twoClassSim <- caret::twoClassSim()

gam <- gen_additive_mod() |>
  set_mode("classification") |>
  fit(Class ~ Linear01, 
      data = twoClassSim)

pred <- predict(gam, head(twoClassSim), type = "prob")

# demonstrating the problem: each column is an array, not a numeric vector
class(pred$.pred_Class1)
#> [1] "array"
dplyr::glimpse(pred)
#> Rows: 6
#> Columns: 2
#> $ .pred_Class1 <dbl> <array[6]>
#> $ .pred_Class2 <dbl> <array[6]>

# session information 
sessionInfo()
#> R version 4.2.0 (2022-04-22)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Monterey 12.3.1
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] parsnip_0.2.1
#> 
#> loaded via a namespace (and not attached):
#>  [1] Rcpp_1.0.8.3         lubridate_1.8.0      lattice_0.20-45     
#>  [4] listenv_0.8.0        tidyr_1.2.0          class_7.3-20        
#>  [7] digest_0.6.29        ipred_0.9-12         foreach_1.5.2       
#> [10] utf8_1.2.2           parallelly_1.31.1    R6_2.5.1            
#> [13] plyr_1.8.7           stats4_4.2.0         reprex_2.0.1        
#> [16] hardhat_0.2.0        evaluate_0.15        ggplot2_3.3.5       
#> [19] highr_0.9            pillar_1.7.0         rlang_1.0.2         
#> [22] caret_6.0-92         data.table_1.14.2    rstudioapi_0.13     
#> [25] rpart_4.1.16         Matrix_1.4-1         rmarkdown_2.14      
#> [28] splines_4.2.0        gower_1.0.0          stringr_1.4.0       
#> [31] munsell_0.5.0        compiler_4.2.0       xfun_0.30           
#> [34] pkgconfig_2.0.3      mgcv_1.8-40          globals_0.14.0      
#> [37] htmltools_0.5.2      nnet_7.3-17          tidyselect_1.1.2    
#> [40] tibble_3.1.6         prodlim_2019.11.13   codetools_0.2-18    
#> [43] fansi_1.0.3          future_1.25.0        crayon_1.5.1        
#> [46] dplyr_1.0.9          withr_2.5.0          ModelMetrics_1.2.2.2
#> [49] MASS_7.3-56          recipes_0.2.0        grid_4.2.0          
#> [52] nlme_3.1-157         gtable_0.3.0         lifecycle_1.0.1     
#> [55] magrittr_2.0.3       pROC_1.18.0          scales_1.2.0        
#> [58] future.apply_1.9.0   cli_3.3.0            stringi_1.7.6       
#> [61] reshape2_1.4.4       fs_1.5.2             timeDate_3043.102   
#> [64] ellipsis_0.3.2       generics_0.1.2       vctrs_0.4.1         
#> [67] lava_1.6.10          iterators_1.0.14     tools_4.2.0         
#> [70] glue_1.6.2           purrr_0.3.4          parallel_4.2.0      
#> [73] fastmap_1.1.0        survival_3.3-1       yaml_2.3.5          
#> [76] colorspace_2.0-3     knitr_1.39

Created on 2022-04-28 by the reprex package (v2.0.1)
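Until a fix is released, one possible user-side workaround is to coerce every prediction column back to a plain vector after calling `predict()`. The sketch below builds a small stand-in tibble with array columns (simulated values, not output from the reprex) and applies `as.vector()` column by column.

```r
library(tibble)

# Stand-in for the problematic predict() output: 1-d array columns
pred <- tibble(
  .pred_Class1 = array(c(0.8, 0.3), dim = 2),
  .pred_Class2 = array(c(0.2, 0.7), dim = 2)
)

# Workaround: drop the array attributes from every column
pred <- as_tibble(lapply(pred, as.vector))
dplyr_free_check <- vapply(pred, function(col) is.null(dim(col)), logical(1))
```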
