
Classes predicted by dbarts engine don't match predicted probabilities #780

@barnabywalker


The problem

I'm having trouble making classification predictions with bart() using the dbarts engine.

The predicted probabilities look correct, but the predicted classes appear to be swapped (e.g. p('pos') = 0.8 but the predicted class is 'neg').

I'm using a factor as the outcome, and I suspect the problem is a mismatch between how parsnip converts it to (0, 1) before passing it to dbarts and how dbart_predict_calc() assigns the class to the predictions.

Is this a bug, or is there something I need to do to make sure the predicted class lines up with the probabilities?
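To make the suspected mismatch concrete, here is a hypothetical base-R illustration. It is not parsnip's actual code: it assumes the outcome is encoded so that the second factor level ("positive") becomes 1 and that the engine returns P(y = 1). Under those assumptions, thresholding with the levels swapped reproduces the pattern in the reprex below.

```r
# Hypothetical illustration of the suspected mismatch (NOT parsnip's internals).
lvls <- c("negative", "positive")
y <- factor(c("negative", "positive", "positive"), levels = lvls)

# Assumed encoding: the second factor level becomes 1 (as glm()'s binomial family does).
y01 <- as.integer(y == lvls[2])

# Suppose the engine returns P(y01 == 1), i.e. P("positive"), for two new rows.
p_second <- c(0.002, 0.995)

# Consistent mapping: predict the second level when its probability exceeds 0.5.
class_ok <- factor(ifelse(p_second > 0.5, lvls[2], lvls[1]), levels = lvls)

# Inconsistent mapping: levels swapped, giving classes that contradict the probabilities.
class_bad <- factor(ifelse(p_second > 0.5, lvls[1], lvls[2]), levels = lvls)

print(data.frame(p_second, class_ok, class_bad))
```

If the real code has the second form, the probabilities would be right while every class is flipped, which is what I'm seeing.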

Reproducible example

library(tidyverse)
#> Warning: package 'ggplot2' was built under R version 4.1.3
#> Warning: package 'tibble' was built under R version 4.1.3
#> Warning: package 'tidyr' was built under R version 4.1.3
#> Warning: package 'dplyr' was built under R version 4.1.3
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.1.3
#> Warning: package 'broom' was built under R version 4.1.3
#> Warning: package 'dials' was built under R version 4.1.3
#> Warning: package 'scales' was built under R version 4.1.3
#> Warning: package 'infer' was built under R version 4.1.3
#> Warning: package 'modeldata' was built under R version 4.1.3
#> Warning: package 'parsnip' was built under R version 4.1.3
#> Warning: package 'recipes' was built under R version 4.1.3
#> Warning: package 'rsample' was built under R version 4.1.3
#> Warning: package 'tune' was built under R version 4.1.3
#> Warning: package 'workflows' was built under R version 4.1.3
#> Warning: package 'workflowsets' was built under R version 4.1.3
#> Warning: package 'yardstick' was built under R version 4.1.3

d <- tibble(
  x1=c(rnorm(100, mean=0), rnorm(100, mean=5)),
  x2=c(rnorm(100, mean=0), rnorm(100, mean=5)),
  y=c(rep("negative", 100), rep("positive", 100))
)

d$y <- factor(d$y, levels=c("negative", "positive"))

model <- bart(mode="classification", engine="dbarts")

model_fit <- fit(model, y ~ ., data=d)

test <- tibble(
  x1=c(0, 5),
  x2=c(0, 5)
)

predict(model_fit, new_data=test, type="prob")
#> # A tibble: 2 x 2
#>   .pred_negative .pred_positive
#>            <dbl>          <dbl>
#> 1        0.998            0.002
#> 2        0.00500          0.995
predict(model_fit, new_data=test, type="class")
#> # A tibble: 2 x 1
#>   .pred_class
#>   <fct>      
#> 1 positive   
#> 2 negative

Created on 2022-08-05 by the reprex package (v2.0.1)
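Since the probabilities look correct, a possible stopgap is to rebuild the class from them. This is a minimal sketch in base R; class_from_probs() is a hypothetical helper, not part of parsnip, and it assumes the probability columns are named .pred_<level> as in the output above.

```r
# Hypothetical helper (not part of parsnip): rebuild the class from the
# .pred_<level> probability columns returned by predict(type = "prob").
class_from_probs <- function(probs, levels) {
  cols <- paste0(".pred_", levels)
  stopifnot(all(cols %in% names(probs)))
  mat <- as.matrix(probs[, cols, drop = FALSE])
  # Pick the level with the highest probability in each row.
  factor(levels[max.col(mat, ties.method = "first")], levels = levels)
}

# Probabilities copied from the reprex output above.
probs <- data.frame(
  .pred_negative = c(0.998, 0.00500),
  .pred_positive = c(0.002, 0.995)
)
class_from_probs(probs, c("negative", "positive"))
```

With the fitted model this would be called as class_from_probs(predict(model_fit, new_data=test, type="prob"), levels(d$y)), which should give classes that agree with the probabilities regardless of how the engine labels them.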

Session info
sessioninfo::session_info()
#> - Session info  --------------------------------------------------------------
#>  hash: bowling, globe showing Americas, lotion bottle
#> 
#>  setting  value
#>  version  R version 4.1.2 (2021-11-01)
#>  os       Windows 10 x64 (build 22622)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_United Kingdom.1252
#>  ctype    English_United Kingdom.1252
#>  tz       Europe/London
#>  date     2022-08-05
#>  pandoc   2.18 @ C:/Program Files/RStudio/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> - Packages -------------------------------------------------------------------
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.1)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.1.2)
#>  broom        * 1.0.0      2022-07-01 [1] CRAN (R 4.1.3)
#>  cellranger     1.1.0      2016-07-27 [1] CRAN (R 4.1.1)
#>  class          7.3-19     2021-05-03 [1] CRAN (R 4.1.2)
#>  cli            3.3.0      2022-04-25 [1] CRAN (R 4.1.3)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.1.2)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.1.3)
#>  crayon         1.5.1      2022-03-26 [1] CRAN (R 4.1.3)
#>  dbarts         0.9-20     2021-10-08 [1] CRAN (R 4.1.1)
#>  DBI            1.1.2      2021-12-20 [1] CRAN (R 4.1.2)
#>  dbplyr         2.1.1      2021-04-06 [1] CRAN (R 4.1.1)
#>  dials        * 1.0.0      2022-06-14 [1] CRAN (R 4.1.3)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.1.1)
#>  digest         0.6.28     2021-09-23 [1] CRAN (R 4.1.1)
#>  dplyr        * 1.0.9      2022-04-28 [1] CRAN (R 4.1.3)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.1.1)
#>  evaluate       0.15       2022-02-18 [1] CRAN (R 4.1.2)
#>  fansi          1.0.0      2022-01-10 [1] CRAN (R 4.1.2)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.1.1)
#>  forcats      * 0.5.1      2021-01-27 [1] CRAN (R 4.1.1)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.1.2)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.1.2)
#>  furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.1)
#>  future         1.25.0     2022-04-24 [1] CRAN (R 4.1.3)
#>  future.apply   1.9.0      2022-04-25 [1] CRAN (R 4.1.3)
#>  generics       0.1.2      2022-01-31 [1] CRAN (R 4.1.2)
#>  ggplot2      * 3.3.6      2022-05-03 [1] CRAN (R 4.1.3)
#>  globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.1)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.1.3)
#>  gower          0.2.2      2020-06-23 [1] CRAN (R 4.1.1)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.1)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.1)
#>  hardhat        1.2.0      2022-06-30 [1] CRAN (R 4.1.3)
#>  haven          2.4.3      2021-08-04 [1] CRAN (R 4.1.1)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.1.1)
#>  hms            1.1.1      2021-09-26 [1] CRAN (R 4.1.1)
#>  htmltools      0.5.2      2021-08-25 [1] CRAN (R 4.1.1)
#>  httr           1.4.2      2020-07-20 [1] CRAN (R 4.1.1)
#>  infer        * 1.0.2      2022-06-10 [1] CRAN (R 4.1.3)
#>  ipred          0.9-12     2021-09-15 [1] CRAN (R 4.1.1)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.1.2)
#>  jsonlite       1.8.0      2022-02-22 [1] CRAN (R 4.1.2)
#>  knitr          1.39       2022-04-26 [1] CRAN (R 4.1.3)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.1.2)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.1.1)
#>  lhs            1.1.3      2021-09-08 [1] CRAN (R 4.1.1)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.1.1)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.1)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.1.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.1.3)
#>  MASS           7.3-54     2021-05-03 [1] CRAN (R 4.1.2)
#>  Matrix         1.3-4      2021-06-01 [1] CRAN (R 4.1.2)
#>  modeldata    * 1.0.0      2022-07-01 [1] CRAN (R 4.1.3)
#>  modelr         0.1.8      2020-05-19 [1] CRAN (R 4.1.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.1)
#>  nnet           7.3-16     2021-05-03 [1] CRAN (R 4.1.2)
#>  parallelly     1.31.1     2022-04-22 [1] CRAN (R 4.1.3)
#>  parsnip      * 1.0.0      2022-06-16 [1] CRAN (R 4.1.3)
#>  pillar         1.7.0      2022-02-01 [1] CRAN (R 4.1.2)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.1)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.1)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.1.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.1)
#>  Rcpp           1.0.7      2021-07-07 [1] CRAN (R 4.1.1)
#>  readr        * 2.1.0      2021-11-11 [1] CRAN (R 4.1.2)
#>  readxl         1.3.1      2019-03-13 [1] CRAN (R 4.1.1)
#>  recipes      * 1.0.1      2022-07-07 [1] CRAN (R 4.1.3)
#>  reprex         2.0.1      2021-08-05 [1] CRAN (R 4.1.1)
#>  rlang          1.0.4      2022-07-12 [1] CRAN (R 4.1.3)
#>  rmarkdown      2.14       2022-04-25 [1] CRAN (R 4.1.3)
#>  rpart          4.1-15     2019-04-12 [1] CRAN (R 4.1.2)
#>  rsample      * 1.0.0      2022-06-24 [1] CRAN (R 4.1.3)
#>  rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.1.1)
#>  rvest          1.0.2      2021-10-16 [1] CRAN (R 4.1.1)
#>  scales       * 1.2.0      2022-04-13 [1] CRAN (R 4.1.3)
#>  sessioninfo    1.2.1      2021-11-02 [1] CRAN (R 4.1.2)
#>  stringi        1.7.6      2021-11-29 [1] CRAN (R 4.1.2)
#>  stringr      * 1.4.0      2019-02-10 [1] CRAN (R 4.1.1)
#>  survival       3.2-13     2021-08-24 [1] CRAN (R 4.1.2)
#>  tibble       * 3.1.8      2022-07-22 [1] CRAN (R 4.1.3)
#>  tidymodels   * 1.0.0      2022-07-13 [1] CRAN (R 4.1.3)
#>  tidyr        * 1.2.0      2022-02-01 [1] CRAN (R 4.1.3)
#>  tidyselect     1.1.2      2022-02-21 [1] CRAN (R 4.1.2)
#>  tidyverse    * 1.3.1      2021-04-15 [1] CRAN (R 4.1.1)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.1)
#>  tune         * 1.0.0      2022-07-07 [1] CRAN (R 4.1.3)
#>  tzdb           0.2.0      2021-10-27 [1] CRAN (R 4.1.1)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.1)
#>  vctrs          0.4.1      2022-04-13 [1] CRAN (R 4.1.3)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.1.3)
#>  workflows    * 1.0.0      2022-07-05 [1] CRAN (R 4.1.3)
#>  workflowsets * 1.0.0      2022-07-12 [1] CRAN (R 4.1.3)
#>  xfun           0.30       2022-03-02 [1] CRAN (R 4.1.3)
#>  xml2           1.3.3      2021-11-30 [1] CRAN (R 4.1.2)
#>  yaml           2.2.1      2020-02-01 [1] CRAN (R 4.1.1)
#>  yardstick    * 1.0.0      2022-06-06 [1] CRAN (R 4.1.3)
#> 
#>  [1] C:/Users/bw42kg/Documents/R/R-4.1.2/library
#> 
#> ------------------------------------------------------------------------------
