Closed
Description
The problem
I'm having trouble making classification predictions with bart using the dbarts engine.
The predicted probabilities look correct, but the predicted classes appear to be swapped (e.g. p('pos') = 0.8 but the predicted class is 'neg').
I'm using a factor as the target and I have a feeling the problem might be a mismatch between how parsnip converts that to (0, 1) to feed into dbarts and how dbart_predict_calc sets the class on the predictions.
Is this a bug, or is there something I need to do to make sure the predicted class lines up with the probabilities?
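A minimal base-R sketch of the kind of mismatch suspected above. This is a hypothetical illustration, not parsnip's or dbarts' actual code: the encoding `as.integer(y) - 1L`, the 0.5 threshold, and the made-up probabilities in `p_event` are all assumptions. It only shows that mapping the 0/1 outcome back to factor levels in the wrong order inverts every predicted class while leaving the probabilities untouched.

```r
# Two-level outcome, with levels in the same order as in the reprex.
y <- factor(c("negative", "positive", "positive"),
            levels = c("negative", "positive"))
lvl <- levels(y)

# Assumed encoding: first level -> 0, second level -> 1.
y01 <- as.integer(y) - 1L

# Made-up probabilities that the encoded outcome equals 1 (the second level).
p_event <- c(0.1, 0.9, 0.8)

# Decoding that agrees with the encoding: p > 0.5 means the second level.
consistent <- factor(ifelse(p_event > 0.5, lvl[2], lvl[1]), levels = lvl)

# Decoding with the levels in the opposite order: every class comes out flipped.
flipped <- factor(ifelse(p_event > 0.5, lvl[1], lvl[2]), levels = lvl)

print(data.frame(y, p_event, consistent, flipped))
```

If something like the second decoding happens when type="class" is requested, the symptom would match what is reported here: correct probabilities, inverted classes.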
Reproducible example
library(tidyverse)
#> Warning: package 'ggplot2' was built under R version 4.1.3
#> Warning: package 'tibble' was built under R version 4.1.3
#> Warning: package 'tidyr' was built under R version 4.1.3
#> Warning: package 'dplyr' was built under R version 4.1.3
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.1.3
#> Warning: package 'broom' was built under R version 4.1.3
#> Warning: package 'dials' was built under R version 4.1.3
#> Warning: package 'scales' was built under R version 4.1.3
#> Warning: package 'infer' was built under R version 4.1.3
#> Warning: package 'modeldata' was built under R version 4.1.3
#> Warning: package 'parsnip' was built under R version 4.1.3
#> Warning: package 'recipes' was built under R version 4.1.3
#> Warning: package 'rsample' was built under R version 4.1.3
#> Warning: package 'tune' was built under R version 4.1.3
#> Warning: package 'workflows' was built under R version 4.1.3
#> Warning: package 'workflowsets' was built under R version 4.1.3
#> Warning: package 'yardstick' was built under R version 4.1.3
d <- tibble(
  x1 = c(rnorm(100, mean = 0), rnorm(100, mean = 5)),
  x2 = c(rnorm(100, mean = 0), rnorm(100, mean = 5)),
  y = c(rep("negative", 100), rep("positive", 100))
)
d$y <- factor(d$y, levels=c("negative", "positive"))
model <- bart(mode="classification", engine="dbarts")
model_fit <- fit(model, y ~ ., data=d)
test <- tibble(
  x1 = c(0, 5),
  x2 = c(0, 5)
)
predict(model_fit, new_data=test, type="prob")
#> # A tibble: 2 x 2
#>   .pred_negative .pred_positive
#>            <dbl>          <dbl>
#> 1        0.998          0.002
#> 2        0.00500        0.995
predict(model_fit, new_data=test, type="class")
#> # A tibble: 2 x 1
#>   .pred_class
#>   <fct>
#> 1 positive
#> 2 negative

Created on 2022-08-05 by the reprex package (v2.0.1)
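Until the cause is confirmed, one possible workaround is to ignore type="class" and derive the class from the probability columns instead. The sketch below uses base R only; `class_from_prob` is a hypothetical helper (not part of parsnip), and the probabilities are copied by hand from the type="prob" output above rather than produced by a fitted model. It assumes the columns follow the `.pred_<level>` naming shown in that output.

```r
# Hypothetical helper: pick, for each row, the level whose probability
# column is largest. Assumes columns are named ".pred_<level>".
class_from_prob <- function(prob, levels) {
  cols <- paste0(".pred_", levels)
  stopifnot(all(cols %in% names(prob)))
  # max.col() returns the index of the largest entry in each row.
  idx <- max.col(as.matrix(prob[, cols]), ties.method = "first")
  factor(levels[idx], levels = levels)
}

# Probabilities copied from the type = "prob" output in the reprex.
prob <- data.frame(
  .pred_negative = c(0.998, 0.005),
  .pred_positive = c(0.002, 0.995)
)

pred_class <- class_from_prob(prob, levels = c("negative", "positive"))
print(pred_class)
```

With a fitted model, the same helper could be applied to the result of predict(model_fit, new_data=test, type="prob"), passing levels(d$y) as the levels.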
Session info
sessioninfo::session_info()
#> - Session info --------------------------------------------------------------
#> hash: bowling, globe showing Americas, lotion bottle
#>
#> setting value
#> version R version 4.1.2 (2021-11-01)
#> os Windows 10 x64 (build 22622)
#> system x86_64, mingw32
#> ui RTerm
#> language (EN)
#> collate English_United Kingdom.1252
#> ctype English_United Kingdom.1252
#> tz Europe/London
#> date 2022-08-05
#> pandoc 2.18 @ C:/Program Files/RStudio/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> - Packages -------------------------------------------------------------------
#> package * version date (UTC) lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.1.1)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.1.2)
#> broom * 1.0.0 2022-07-01 [1] CRAN (R 4.1.3)
#> cellranger 1.1.0 2016-07-27 [1] CRAN (R 4.1.1)
#> class 7.3-19 2021-05-03 [1] CRAN (R 4.1.2)
#> cli 3.3.0 2022-04-25 [1] CRAN (R 4.1.3)
#> codetools 0.2-18 2020-11-04 [1] CRAN (R 4.1.2)
#> colorspace 2.0-3 2022-02-21 [1] CRAN (R 4.1.3)
#> crayon 1.5.1 2022-03-26 [1] CRAN (R 4.1.3)
#> dbarts 0.9-20 2021-10-08 [1] CRAN (R 4.1.1)
#> DBI 1.1.2 2021-12-20 [1] CRAN (R 4.1.2)
#> dbplyr 2.1.1 2021-04-06 [1] CRAN (R 4.1.1)
#> dials * 1.0.0 2022-06-14 [1] CRAN (R 4.1.3)
#> DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.1.1)
#> digest 0.6.28 2021-09-23 [1] CRAN (R 4.1.1)
#> dplyr * 1.0.9 2022-04-28 [1] CRAN (R 4.1.3)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.1.1)
#> evaluate 0.15 2022-02-18 [1] CRAN (R 4.1.2)
#> fansi 1.0.0 2022-01-10 [1] CRAN (R 4.1.2)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.1.1)
#> forcats * 0.5.1 2021-01-27 [1] CRAN (R 4.1.1)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.1.2)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.1.2)
#> furrr 0.2.3 2021-06-25 [1] CRAN (R 4.1.1)
#> future 1.25.0 2022-04-24 [1] CRAN (R 4.1.3)
#> future.apply 1.9.0 2022-04-25 [1] CRAN (R 4.1.3)
#> generics 0.1.2 2022-01-31 [1] CRAN (R 4.1.2)
#> ggplot2 * 3.3.6 2022-05-03 [1] CRAN (R 4.1.3)
#> globals 0.14.0 2020-11-22 [1] CRAN (R 4.1.1)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.1.3)
#> gower 0.2.2 2020-06-23 [1] CRAN (R 4.1.1)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.1.1)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.1.1)
#> hardhat 1.2.0 2022-06-30 [1] CRAN (R 4.1.3)
#> haven 2.4.3 2021-08-04 [1] CRAN (R 4.1.1)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.1.1)
#> hms 1.1.1 2021-09-26 [1] CRAN (R 4.1.1)
#> htmltools 0.5.2 2021-08-25 [1] CRAN (R 4.1.1)
#> httr 1.4.2 2020-07-20 [1] CRAN (R 4.1.1)
#> infer * 1.0.2 2022-06-10 [1] CRAN (R 4.1.3)
#> ipred 0.9-12 2021-09-15 [1] CRAN (R 4.1.1)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.1.2)
#> jsonlite 1.8.0 2022-02-22 [1] CRAN (R 4.1.2)
#> knitr 1.39 2022-04-26 [1] CRAN (R 4.1.3)
#> lattice 0.20-45 2021-09-22 [1] CRAN (R 4.1.2)
#> lava 1.6.10 2021-09-02 [1] CRAN (R 4.1.1)
#> lhs 1.1.3 2021-09-08 [1] CRAN (R 4.1.1)
#> lifecycle 1.0.1 2021-09-24 [1] CRAN (R 4.1.1)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.1.1)
#> lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.1.1)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.1.3)
#> MASS 7.3-54 2021-05-03 [1] CRAN (R 4.1.2)
#> Matrix 1.3-4 2021-06-01 [1] CRAN (R 4.1.2)
#> modeldata * 1.0.0 2022-07-01 [1] CRAN (R 4.1.3)
#> modelr 0.1.8 2020-05-19 [1] CRAN (R 4.1.1)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.1.1)
#> nnet 7.3-16 2021-05-03 [1] CRAN (R 4.1.2)
#> parallelly 1.31.1 2022-04-22 [1] CRAN (R 4.1.3)
#> parsnip * 1.0.0 2022-06-16 [1] CRAN (R 4.1.3)
#> pillar 1.7.0 2022-02-01 [1] CRAN (R 4.1.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.1.1)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.1.1)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.1.1)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.1.1)
#> Rcpp 1.0.7 2021-07-07 [1] CRAN (R 4.1.1)
#> readr * 2.1.0 2021-11-11 [1] CRAN (R 4.1.2)
#> readxl 1.3.1 2019-03-13 [1] CRAN (R 4.1.1)
#> recipes * 1.0.1 2022-07-07 [1] CRAN (R 4.1.3)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 4.1.1)
#> rlang 1.0.4 2022-07-12 [1] CRAN (R 4.1.3)
#> rmarkdown 2.14 2022-04-25 [1] CRAN (R 4.1.3)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 4.1.2)
#> rsample * 1.0.0 2022-06-24 [1] CRAN (R 4.1.3)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.1.1)
#> rvest 1.0.2 2021-10-16 [1] CRAN (R 4.1.1)
#> scales * 1.2.0 2022-04-13 [1] CRAN (R 4.1.3)
#> sessioninfo 1.2.1 2021-11-02 [1] CRAN (R 4.1.2)
#> stringi 1.7.6 2021-11-29 [1] CRAN (R 4.1.2)
#> stringr * 1.4.0 2019-02-10 [1] CRAN (R 4.1.1)
#> survival 3.2-13 2021-08-24 [1] CRAN (R 4.1.2)
#> tibble * 3.1.8 2022-07-22 [1] CRAN (R 4.1.3)
#> tidymodels * 1.0.0 2022-07-13 [1] CRAN (R 4.1.3)
#> tidyr * 1.2.0 2022-02-01 [1] CRAN (R 4.1.3)
#> tidyselect 1.1.2 2022-02-21 [1] CRAN (R 4.1.2)
#> tidyverse * 1.3.1 2021-04-15 [1] CRAN (R 4.1.1)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.1.1)
#> tune * 1.0.0 2022-07-07 [1] CRAN (R 4.1.3)
#> tzdb 0.2.0 2021-10-27 [1] CRAN (R 4.1.1)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.1.1)
#> vctrs 0.4.1 2022-04-13 [1] CRAN (R 4.1.3)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.1.3)
#> workflows * 1.0.0 2022-07-05 [1] CRAN (R 4.1.3)
#> workflowsets * 1.0.0 2022-07-12 [1] CRAN (R 4.1.3)
#> xfun 0.30 2022-03-02 [1] CRAN (R 4.1.3)
#> xml2 1.3.3 2021-11-30 [1] CRAN (R 4.1.2)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.1.1)
#> yardstick * 1.0.0 2022-06-06 [1] CRAN (R 4.1.3)
#>
#> [1] C:/Users/bw42kg/Documents/R/R-4.1.2/library
#>
#> ------------------------------------------------------------------------------