
Simplify the required structure for cvfits #456

Merged: fweber144 merged 4 commits into stan-dev:master from enhance_kfold on Sep 25, 2023.


fweber144 (Collaborator)

This simplifies the required structure for the object passed to argument `cvfits` of `init_refmodel()` (see the NEWS.md entry added here for details). The reason for this change is that the output of `loo::kfold()` can't be used straightforwardly anyway, as the following example shows:

```r
# Example data with a stratification factor (4 groups of equal size).
data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
n_strat <- 4L
dat$strat_fac <- gl(n = n_strat, k = floor(nrow(dat) / n_strat),
                    length = nrow(dat), labels = paste0("gr", seq_len(n_strat)))
# Reference model fit.
library(rstanarm)
refm_fit <- stan_glm(y ~ X1 + X2 + X3 + X4 + X5,
                     data = dat,
                     chains = 1,
                     iter = 500,
                     seed = 1140350788,
                     refresh = 0)
# Stratified K-fold CV with K = 10, keeping the K refits.
set.seed(3424511)
refm_kfold <- kfold(
  refm_fit,
  folds = loo::kfold_split_stratified(K = 10, x = dat$strat_fac),
  save_fits = TRUE,
  cores = 1
)
# Current `cvfits` structure: the refits in a sub-list `fits`, plus an
# attribute `folds` giving, for each observation, the fold it was omitted in.
cvfits_crr <- structure(
  list(fits = refm_kfold$fits[, "fit"]),
  folds = sapply(seq_len(nrow(dat)), function(ii) {
    which(sapply(refm_kfold$fits[, "omitted"], "%in%", x = ii))
  })
)
length(refm_kfold$fits)
## --> 20
length(cvfits_crr$fits)
## --> 10
lapply(refm_kfold$fits, class)
## --> First 10 are `stanreg`s, last 10 are `integer` vectors.
```
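The `folds` attribute of `cvfits_crr` is obtained by mapping each observation index to the fold in which it was held out. A base-R toy version of that mapping, where the list `omitted` is a hypothetical stand-in for `refm_kfold$fits[, "omitted"]`:

```r
# Hypothetical held-out indices for K = 3 folds over N = 6 observations,
# standing in for refm_kfold$fits[, "omitted"].
omitted <- list(c(1L, 4L), c(2L, 5L), c(3L, 6L))

# For observation ii, sapply(omitted, "%in%", x = ii) evaluates
# ii %in% omitted[[k]] for every fold k; which() then returns the index of
# the fold that held ii out (assumed to be exactly one fold).
folds <- sapply(seq_len(6L), function(ii) {
  which(sapply(omitted, "%in%", x = ii))
})
folds
## [1] 1 2 3 1 2 3
```

If some observation were held out in zero folds or in several folds, `which()` would not return a single index and `sapply()` would no longer return a plain integer vector.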

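So, since the `kfold()` output has to be rearranged manually in any case, having the K reference model refits in a sub-`list` called `fits` is not necessary and only complicates things. With this PR, the content of that sub-`list` can be moved one level up, directly into `cvfits`.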
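A base-R sketch contrasting the two layouts, assuming (per the description above) that after this PR `cvfits` is simply the list of K refits carrying the `folds` attribute; the NEWS.md entry remains the authoritative statement of the requirements. The object `fits_mat` is a hypothetical stand-in that mimics the K-by-2 list-matrix of `refm_kfold$fits`, with strings as placeholders for the `stanreg` refits:

```r
# Hypothetical toy setup: K = 3 folds over N = 6 observations. The strings
# are placeholders for the `stanreg` refits returned by kfold().
K <- 3L
fit_list <- list("refit_1", "refit_2", "refit_3")
omitted_list <- list(c(1L, 4L), c(2L, 5L), c(3L, 6L))

# Mimic the K-by-2 list-matrix layout of `refm_kfold$fits`: column "fit"
# holds the refits, column "omitted" the held-out indices.
fits_mat <- matrix(c(fit_list, omitted_list), nrow = K, ncol = 2,
                   dimnames = list(NULL, c("fit", "omitted")))
length(fits_mat)  # 2 * K elements, like the 20 above for K = 10

# Fold index of each observation (as in `cvfits_crr` above).
folds <- sapply(seq_len(6L), function(ii) {
  which(sapply(fits_mat[, "omitted"], "%in%", x = ii))
})

# Old layout: refits nested in a sub-list called `fits`.
cvfits_old <- structure(list(fits = fits_mat[, "fit"]), folds = folds)
# Simplified layout: the refits one level up, directly in `cvfits`.
cvfits_new <- structure(fits_mat[, "fit"], folds = folds)

length(cvfits_old)
## [1] 1
length(cvfits_new)
## [1] 3
```

In the simplified layout, `cvfits[[k]]` is the k-th refit itself rather than `cvfits$fits[[k]]`.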

fweber144 merged commit e4eb097 into stan-dev:master on Sep 25, 2023.
fweber144 deleted the enhance_kfold branch on September 25, 2023, 18:59.
fweber144 added a commit that referenced this pull request Nov 22, 2023
…t correct

(attribute `folds` was dropped).