Skip to content

Conversation

@qiushiyan
Copy link
Contributor

@qiushiyan qiushiyan commented Aug 9, 2022

Fix #780

@simonpcouch
Copy link
Contributor

simonpcouch commented Aug 15, 2022

Looks great!

A quick check that this prediction is level-independent:

library(parsnip)

d <- data.frame(
  x1 = c(rnorm(100, mean = 0), rnorm(100, mean = 5)),
  x2 = c(rnorm(100, mean = 0), rnorm(100, mean = 5)),
  y = c(rep("negative", 100), rep("positive", 100))
)

d$y <- factor(d$y, levels = c("negative", "positive"))

model <- bart(mode = "classification", engine = "dbarts")

model_fit <- fit(model, y ~ ., data=d)

test <- data.frame(
  x1 = c(0, 5),
  x2 = c(0, 5)
)

predict(model_fit, new_data = test, type = "prob")
#> # A tibble: 2 × 2
#>   .pred_negative .pred_positive
#>            <dbl>          <dbl>
#> 1        0.992            0.008
#> 2        0.00600          0.994

predict(model_fit, new_data = test, type = "class")
#> # A tibble: 2 × 1
#>   .pred_class
#>   <fct>      
#> 1 negative   
#> 2 positive

# re-order levels to ensure independence
d$y <- factor(d$y, levels = c("positive", "negative"))

model <- bart(mode = "classification", engine = "dbarts")

model_fit <- fit(model, y ~ ., data = d)

predict(model_fit, new_data = test, type = "prob")
#> # A tibble: 2 × 2
#>   .pred_positive .pred_negative
#>            <dbl>          <dbl>
#> 1        0.00500          0.995
#> 2        0.995            0.005

predict(model_fit, new_data = test, type = "class")
#> # A tibble: 2 × 1
#>   .pred_class
#>   <fct>      
#> 1 negative   
#> 2 positive

Created on 2022-08-15 by the reprex package (v2.0.1)

@simonpcouch simonpcouch merged commit de88664 into main Aug 15, 2022
@simonpcouch simonpcouch deleted the dbarts-prediction branch August 15, 2022 20:45
@github-actions
Copy link
Contributor

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Aug 30, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Classes predicted by dbarts engine don't match prediced probabilities

3 participants