int_conformal_cv doesn't work with group_vfold_cv #141

Closed
Jeffrothschild opened this issue Apr 2, 2024 · 1 comment · Fixed by #148

Comments

@Jeffrothschild

The problem

I'm having trouble using int_conformal_cv() when the resamples were created with group_vfold_cv().

Here is an example I modified from https://www.tidymodels.org/learn/models/conformal-regression/#using-resampling-results

Reproducible example

```r
library(tidymodels)  # dplyr, parsnip, rsample, tune, workflows
library(probably)    # int_conformal_cv()

# Simulate a nonlinear regression problem with Gaussian noise
make_data <- function(n, std_dev = 1 / 5) {
  tibble(x = runif(n, min = -1)) %>%
    mutate(
      y = (x^3) + 2 * exp(-6 * (x - 0.3)^2),
      y = y + rnorm(n, sd = std_dev)
    )
}

n <- 1000
set.seed(8383)
# Add a two-level grouping column to use with group_vfold_cv()
train_data <- make_data(n) %>% 
  mutate(color = sample(c('red', 'blue'), n(), replace = TRUE))

set.seed(7292)
test_data <- make_data(10000) %>% 
  mutate(color = sample(c('red', 'blue'), n(), replace = TRUE))

set.seed(493)
folds <- vfold_cv(train_data)
group_folds <- group_vfold_cv(train_data, group = color)

set.seed(484)
nnet_wflow <- 
  workflow(y ~ x, mlp(hidden_units = 4) %>% set_mode("regression"))

# Save out-of-sample predictions and the fitted workflows for each fold
ctrl <- control_resamples(save_pred = TRUE, extract = I)

nnet_rs <- 
  nnet_wflow %>% 
  fit_resamples(folds, control = ctrl)

collect_metrics(nnet_rs)

# This works

cv_int <- int_conformal_cv(nnet_rs)

predict(cv_int, test_data, level = 0.90) %>% bind_cols(test_data)

# try again with group cv and get an error

group_nnet_rs <- 
  nnet_wflow %>% 
  fit_resamples(group_folds, control = ctrl)

collect_metrics(group_nnet_rs)

group_cv_int <- int_conformal_cv(group_nnet_rs)

# Error in if (rs$att$class != "vfold_cv") { : the condition has length > 1
```


reprex::reprex(si = TRUE)
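For context, the error message points at a general R pitfall rather than anything specific to conformal inference: `!=` is vectorized, so if the stored class attribute has more than one element, the comparison returns several logicals, and since R 4.2.0 `if()` stops with an error on a condition of length > 1. The sketch below reproduces that with a hypothetical class vector (`att_class` and its values are illustrative assumptions, not what rsample actually stores) and shows two length-safe checks. It is not the patch from #148.

```r
# Hypothetical class vector with more than one element, standing in for
# whatever rs$att$class holds for group_vfold_cv() resamples (assumption).
att_class <- c("group_vfold_cv", "extra_class")

# `!=` is vectorized: one logical per element, so length 2 here.
cond <- att_class != "vfold_cv"

# In R >= 4.2.0, if() on a length > 1 condition is an error; capture its message.
res <- tryCatch(
  if (cond) "not vfold" else "vfold",
  error = function(e) conditionMessage(e)
)

# Length-safe alternatives that always return a single TRUE/FALSE:
is_plain_vfold <- identical(att_class, "vfold_cv")  # exact, whole-vector match
mentions_vfold <- "vfold_cv" %in% att_class         # membership test
```

Either `identical()` or `%in%` keeps the condition at length one, so `if()` can never see a vector no matter how many classes are recorded.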
@topepo
Member

topepo commented May 30, 2024

There's a PR to fix it. Thanks for reporting this.

It will still issue a warning, though.

topepo added a commit that referenced this issue May 30, 2024