Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine runtime estimation for forward search #459

Merged
merged 2 commits into from Sep 30, 2023

Conversation

fweber144
Copy link
Collaborator

@fweber144 fweber144 commented Sep 30, 2023

This refines the estimate of the runtime of the forward search that remains after the projection onto the intercept-only submodel; in particular, it now allows for an interval estimate in the presence of multilevel and/or additive ("smooth") terms.

The factors used for scaling up the runtime estimate (coming from the intercept-only projection) were derived empirically as follows:

# Source for the data-generating mechanism and the reference model: Example
# section of `?rstanarm::stan_gamm4`.
# Simulate data with smooth effects of x0 and x2 plus a linear effect of x1.
set.seed(7456)
dat <- mgcv::gamSim(1, n = 200, scale = 2)
# Add a 20-level grouping factor with random group-level intercept shifts,
# so the reference model contains a multilevel term as well.
dat$fac <- fac <- as.factor(sample(1:20, 200, replace = TRUE))
dat$y <- dat$y + model.matrix(~ fac - 1) %*% rnorm(20) * 0.5
# Reference model: GAMM with two smooth terms, one linear term, and a
# group-level intercept. High `adapt_delta` to avoid divergent transitions.
rfit <- rstanarm::stan_gamm4(
  y ~ s(x0) + x1 + s(x2),
  random = ~ (1 | fac),
  data = dat,
  cores = 4,
  seed = 1140350788,
  adapt_delta = 0.99,
  refresh = 0
)

# With projpred at commit cc0d3064:
devtools::load_all(".")
set.seed(8234467)
# Reference-model object and clustered reference draws (20 clusters) that are
# projected onto each candidate submodel below.
refm <- get_refmodel(rfit)
refd <- get_refdist(refm, nclusters = 20)

# Benchmark 1: intercept-only submodel (empty set of predictor terms).
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = character(),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#     min       lq     mean  median       uq      max neval
# 13.9557 14.16054 17.08022 14.3025 14.82752 153.7715   100

# Benchmark 2: GLM submodel (one linear term only).
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#      min       lq     mean   median       uq      max neval
# 17.81272 18.35631 19.42337 18.52197 18.82939 61.19784   100

# Benchmark 3: GLMM submodel (linear term + multilevel term).
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1", "(1 | fac)"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#      min       lq     mean   median      uq      max neval
# 473.7028 489.0631 502.0515 497.7506 503.722 761.9663   100

# Benchmark 4: GAM submodel (linear term + smooth term).
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1", "s(x0)"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#     min       lq     mean   median       uq      max neval
# 166.975 170.1548 179.7964 177.2802 185.0535 228.0465   100

# Benchmark 5: GAMM submodel (linear + multilevel + smooth term).
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1", "(1 | fac)", "s(x0)"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: seconds -- NOTE: this last benchmark is reported in seconds, not ms.
#      min       lq     mean   median       uq     max neval
# 1.017806 1.043563 1.079292 1.061689 1.084159 1.35676   100
###

From these microbenchmark results, we obtain the following scaling factors. (I should have assigned the microbenchmark::microbenchmark() outputs to objects and computed the ratios from those instead of hard-coding the timings here, but I was too lazy to re-run the benchmarks.)

# Summary statistics (min, lq, mean, median, uq, max) copied from the
# microbenchmark outputs above, all expressed in milliseconds.
bm_empty <- c(13.9557, 14.16054, 17.08022, 14.3025, 14.82752, 153.7715)
bm_glm   <- c(17.81272, 18.35631, 19.42337, 18.52197, 18.82939, 61.19784)
bm_glmm  <- c(473.7028, 489.0631, 502.0515, 497.7506, 503.722, 761.9663)
bm_gam   <- c(166.975, 170.1548, 179.7964, 177.2802, 185.0535, 228.0465)
# The GAMM benchmark was reported in seconds; convert to milliseconds:
bm_gamm  <- c(1.017806, 1.043563, 1.079292, 1.061689, 1.084159, 1.35676) * 1e3

# Elementwise ratios of the summary statistics, giving the empirical
# slow-down factor for each submodel class.
factor_glm_vs_empty <- bm_glm / bm_empty
factor_glm_vs_empty
# [1] 1.2763760 1.2963001 1.1371850 1.2950163 1.2698948 0.3979791
## --> ca. 1.3 from intercept-only to GLM submodel

factor_glmm_vs_glm <- bm_glmm / bm_glm
factor_glmm_vs_glm
# [1] 26.59351 26.64278 25.84781 26.87352 26.75190 12.45087
## --> ca. 26.9 from GLM to GLMM submodel

factor_gam_vs_glm <- bm_gam / bm_glm
factor_gam_vs_glm
# [1] 9.373919 9.269554 9.256705 9.571347 9.827907 3.726382
## --> ca. 9.8 from GLM to GAM submodel

factor_gamm_vs_glm <- bm_gamm / bm_glm
factor_gamm_vs_glm
# [1] 57.13928 56.85037 55.56667 57.32052 57.57802 22.17006
## --> ca. 57.6 from GLM to GAMM submodel

in particular allowing for an interval estimate in case of multilevel and/or
additive ("smooth") terms.
@fweber144
Copy link
Collaborator Author

On my machine, the following example causes the runtime message to be displayed:

# Data --------------------------------------------------------------------

# Gaussian toy data shipped with projpred, augmented with an 8-level
# grouping factor carrying group-specific intercepts and X1 slopes.
data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                labels = paste0("gr", seq_len(8)))
set.seed(457211)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
# True linear predictor: global intercept + group intercept + group-varying
# slope for X1; then add Gaussian noise.
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Make the dataset artificially long:
dat <- do.call(rbind, replicate(6, dat, simplify = FALSE))
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------

# Multilevel (GLMM) reference model: this is the case where, on the author's
# machine, the new runtime-estimation message IS displayed.
rfit_train <- rstanarm::stan_glmer(
  y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10 + X11 + X12 + X13 + X14 +
    X15 + X16 + X17 + X18 + X19 + X20 + (1 | group),
  data = dat_train,
  cores = 4,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

# With projpred at commit c7b1d2d7:
devtools::load_all(".")
options(projpred.extra_verbose = TRUE)

# Hold-out test data in the list format expected by varsel()'s `d_test`.
d_test_list <- list(
  data = dat_test[, names(dat_test) != "y"],
  offset = rep(0, nrow(dat_test)),
  weights = rep(1, nrow(dat_test)),
  y = dat_test[["y"]]
)
# Timestamps bracket the call to eyeball the total wall-clock runtime.
Sys.time()
vs <- varsel(rfit_train,
             d_test = d_test_list,
             refit_prj = FALSE,
             seed = 46782345)
Sys.time()

And the following example doesn't:

# Data --------------------------------------------------------------------

# Identical data setup to the GLMM example above (same seed, same
# group-specific truth), so the two runs differ only in the reference model.
data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                labels = paste0("gr", seq_len(8)))
set.seed(457211)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Make the dataset artificially long:
dat <- do.call(rbind, replicate(6, dat, simplify = FALSE))
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------

# Plain GLM reference model (no multilevel term): the case where, on the
# author's machine, the runtime-estimation message is NOT displayed.
rfit_train <- rstanarm::stan_glm(
  y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10 + X11 + X12 + X13 + X14 +
    X15 + X16 + X17 + X18 + X19 + X20,
  data = dat_train,
  cores = 4,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

# With projpred at commit c7b1d2d7:
devtools::load_all(".")
options(projpred.extra_verbose = TRUE)

# Hold-out test data in the list format expected by varsel()'s `d_test`.
d_test_list <- list(
  data = dat_test[, names(dat_test) != "y"],
  offset = rep(0, nrow(dat_test)),
  weights = rep(1, nrow(dat_test)),
  y = dat_test[["y"]]
)
# Timestamps bracket the call to eyeball the total wall-clock runtime.
Sys.time()
vs <- varsel(rfit_train,
             d_test = d_test_list,
             refit_prj = FALSE,
             seed = 46782345)
Sys.time()

@fweber144 fweber144 merged commit b76e685 into stan-dev:master Sep 30, 2023
@fweber144 fweber144 deleted the timing_fw_search branch September 30, 2023 12:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant