Skip to content

Commit

Permalink
Add a new transform argument to support SBC
Browse files Browse the repository at this point in the history
  • Loading branch information
wlandau-lilly committed Dec 15, 2022
1 parent e079fc6 commit f37eb66
Show file tree
Hide file tree
Showing 25 changed files with 536 additions and 25 deletions.
7 changes: 6 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,22 @@ Imports:
withr (>= 2.1.2)
Suggests:
dplyr (>= 1.0.2),
ggplot2 (>= 3.0.0),
knitr (>= 1.30),
purrr (>= 0.3.0),
R.utils (>= 2.10.1),
rmarkdown (>= 2.3),
SBC (>= 0.2.0),
testthat (>= 3.0.0),
tidyr (>= 1.0.0),
visNetwork (>= 2.0.9)
Remotes:
hyunjimoon/SBC,
stan-dev/cmdstanr,
SystemRequirements: CmdStan >= 2.25.0
Encoding: UTF-8
Language: en-US
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
VignetteBuilder: knitr
Config/testthat/edition: 3
17 changes: 12 additions & 5 deletions R/tar_stan_gq_rep.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ tar_stan_gq_rep <- function(
variables = NULL,
summaries = NULL,
summary_args = NULL,
transform = NULL,
tidy_eval = targets::tar_option_get("tidy_eval"),
packages = targets::tar_option_get("packages"),
library = targets::tar_option_get("library"),
Expand All @@ -66,6 +67,7 @@ tar_stan_gq_rep <- function(
targets::tar_assert_chr(stan_files)
targets::tar_assert_unique(stan_files)
lapply(stan_files, assert_stan_file)
assert_transform(transform)
name_stan <- produce_stan_names(stan_files)
name_file <- paste0(name, "_file")
name_lines <- paste0(name, "_lines")
Expand Down Expand Up @@ -120,7 +122,8 @@ tar_stan_gq_rep <- function(
data_copy = data_copy,
variables = variables,
summaries = summaries,
summary_args = summary_args
summary_args = summary_args,
transform = transform
)
command <- as.expression(as.call(args))
pattern_data <- substitute(map(x), env = list(x = sym_batch))
Expand Down Expand Up @@ -275,7 +278,8 @@ tar_stan_gq_rep_run <- function(
data_copy,
variables,
summaries,
summary_args
summary_args,
transform
) {
if (!is.null(stdout)) {
withr::local_output_sink(new = stdout, append = TRUE)
Expand Down Expand Up @@ -318,7 +322,8 @@ tar_stan_gq_rep_run <- function(
data_copy = data_copy,
variables = variables,
summaries = summaries,
summary_args = summary_args
summary_args = summary_args,
transform = transform
)
)
out$.file <- stan_path
Expand All @@ -340,7 +345,8 @@ tar_stan_gq_rep_run_rep <- function(
data_copy,
variables,
summaries,
summary_args
summary_args,
transform
) {
stan_seed <- data$.seed + 1L
stan_seed <- if_any(is.null(seed), stan_seed, stan_seed + seed)
Expand All @@ -367,6 +373,7 @@ tar_stan_gq_rep_run_rep <- function(
inc_warmup = NULL,
data = data,
data_copy = data_copy,
seed = stan_seed
seed = stan_seed,
transform = transform
)
}
2 changes: 2 additions & 0 deletions R/tar_stan_gq_rep_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ tar_stan_gq_rep_draws <- function(
threads_per_chain = NULL,
variables = NULL,
data_copy = character(0),
transform = NULL,
tidy_eval = targets::tar_option_get("tidy_eval"),
packages = targets::tar_option_get("packages"),
library = targets::tar_option_get("library"),
Expand Down Expand Up @@ -138,6 +139,7 @@ tar_stan_gq_rep_draws <- function(
variables = variables,
summaries = NULL,
summary_args = NULL,
transform = substitute(transform),
tidy_eval = tidy_eval,
packages = packages,
library = library,
Expand Down
21 changes: 17 additions & 4 deletions R/tar_stan_mcmc_rep.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ tar_stan_mcmc_rep <- function(
inc_warmup = FALSE,
summaries = NULL,
summary_args = NULL,
transform = NULL,
tidy_eval = targets::tar_option_get("tidy_eval"),
packages = targets::tar_option_get("packages"),
library = targets::tar_option_get("library"),
Expand All @@ -115,6 +116,7 @@ tar_stan_mcmc_rep <- function(
targets::tar_assert_unique(stan_files, "stan_files must be unique")
targets::tar_assert_chr(data_copy, "data_copy must be a character vector")
lapply(stan_files, assert_stan_file)
assert_transform(transform)
name_stan <- produce_stan_names(stan_files)
name_file <- paste0(name, "_file")
name_lines <- paste0(name, "_lines")
Expand Down Expand Up @@ -193,7 +195,8 @@ tar_stan_mcmc_rep <- function(
data_copy = data_copy,
variables = variables,
summaries = summaries,
summary_args = summary_args
summary_args = summary_args,
transform = transform
)
command <- as.expression(as.call(args))
pattern_data <- substitute(map(x), env = list(x = sym_batch))
Expand Down Expand Up @@ -333,6 +336,12 @@ tar_stan_mcmc_rep <- function(
#' element of your Stan data list with names and dimensions corresponding
#' to those of the model. For details, read
#' <https://docs.ropensci.org/stantargets/articles/simulation.html>.
#' @param transform Symbol or `NULL`, name of a function that accepts
#' arguments `data` and `draws` and returns a data frame. Here,
#' `data` is the JAGS data list supplied to the model, and `draws`
#' is a data frame with one column per model parameter and one row
#' per posterior sample. See the simulation-based calibration (SBC)
#' section of the simulation vignette for an example.
tar_stan_mcmc_rep_run <- function(
stan_file,
stan_name,
Expand Down Expand Up @@ -382,7 +391,8 @@ tar_stan_mcmc_rep_run <- function(
inc_warmup,
variables,
summaries,
summary_args
summary_args,
transform
) {
if (!is.null(stdout)) {
withr::local_output_sink(new = stdout, append = TRUE)
Expand Down Expand Up @@ -449,7 +459,8 @@ tar_stan_mcmc_rep_run <- function(
inc_warmup = inc_warmup,
variables = variables,
summaries = summaries,
summary_args = summary_args
summary_args = summary_args,
transform = transform
)
)
out$.file <- stan_path
Expand Down Expand Up @@ -495,7 +506,8 @@ tar_stan_mcmc_rep_run_rep <- function(
inc_warmup,
data_copy,
summaries,
summary_args
summary_args,
transform
) {
stan_seed <- data$.seed + 1L
stan_seed <- if_any(is.null(seed), stan_seed, stan_seed + seed)
Expand Down Expand Up @@ -541,6 +553,7 @@ tar_stan_mcmc_rep_run_rep <- function(
output_type = output_type,
summaries = summaries,
summary_args = summary_args,
transform = transform,
variables = variables,
inc_warmup = inc_warmup,
data = data,
Expand Down
4 changes: 3 additions & 1 deletion R/tar_stan_mcmc_rep_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @inheritSection tar_stan_mcmc_rep Seeds
#' @details Draws could take up a lot of storage. If storage becomes
#' excessive, please consider thinning the draws or using
#' `tar_stan_mcmc_rep_summaries()` instead.
#' `tar_stan_mcmc_rep_summary()` instead.
#'
#' Most of the arguments are passed to the `$compile()`
#' and `$sample()` methods of the `CmdStanModel` class. If you
Expand Down Expand Up @@ -112,6 +112,7 @@ tar_stan_mcmc_rep_draws <- function(
inc_warmup = FALSE,
variables = NULL,
data_copy = character(0),
transform = NULL,
tidy_eval = targets::tar_option_get("tidy_eval"),
packages = targets::tar_option_get("packages"),
library = targets::tar_option_get("library"),
Expand Down Expand Up @@ -178,6 +179,7 @@ tar_stan_mcmc_rep_draws <- function(
data_copy = data_copy,
inc_warmup = inc_warmup,
variables = variables,
transform = substitute(transform),
tidy_eval = tidy_eval,
packages = packages,
library = library,
Expand Down
17 changes: 12 additions & 5 deletions R/tar_stan_vb_rep.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ tar_stan_vb_rep <- function(
variables = NULL,
summaries = NULL,
summary_args = NULL,
transform = NULL,
tidy_eval = targets::tar_option_get("tidy_eval"),
packages = targets::tar_option_get("packages"),
library = targets::tar_option_get("library"),
Expand All @@ -75,6 +76,7 @@ tar_stan_vb_rep <- function(
targets::tar_assert_chr(stan_files)
targets::tar_assert_unique(stan_files)
lapply(stan_files, assert_stan_file)
assert_transform(transform)
name_stan <- produce_stan_names(stan_files)
name_file <- paste0(name, "_file")
name_lines <- paste0(name, "_lines")
Expand Down Expand Up @@ -139,7 +141,8 @@ tar_stan_vb_rep <- function(
data_copy = data_copy,
variables = variables,
summaries = summaries,
summary_args = summary_args
summary_args = summary_args,
transform = transform
)
command <- as.expression(as.call(args))
pattern_data <- substitute(map(x), env = list(x = sym_batch))
Expand Down Expand Up @@ -304,7 +307,8 @@ tar_stan_vb_rep_run <- function(
data_copy,
variables,
summaries,
summary_args
summary_args,
transform
) {
if (!is.null(stdout)) {
withr::local_output_sink(new = stdout, append = TRUE)
Expand Down Expand Up @@ -357,7 +361,8 @@ tar_stan_vb_rep_run <- function(
data_copy = data_copy,
variables = variables,
summaries = summaries,
summary_args = summary_args
summary_args = summary_args,
transform = transform
)
)
out$.file <- stan_path
Expand Down Expand Up @@ -389,7 +394,8 @@ tar_stan_vb_rep_run_rep <- function(
data_copy,
variables,
summaries,
summary_args
summary_args,
transform
) {
stan_seed <- data$.seed + 1L
stan_seed <- if_any(is.null(seed), stan_seed, stan_seed + seed)
Expand Down Expand Up @@ -426,6 +432,7 @@ tar_stan_vb_rep_run_rep <- function(
inc_warmup = NULL,
data = data,
data_copy = data_copy,
seed = stan_seed
seed = stan_seed,
transform = transform
)
}
2 changes: 2 additions & 0 deletions R/tar_stan_vb_rep_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ tar_stan_vb_rep_draws <- function(
sig_figs = NULL,
data_copy = character(0),
variables = NULL,
transform = NULL,
tidy_eval = targets::tar_option_get("tidy_eval"),
packages = targets::tar_option_get("packages"),
library = targets::tar_option_get("library"),
Expand Down Expand Up @@ -153,6 +154,7 @@ tar_stan_vb_rep_draws <- function(
variables = variables,
summaries = NULL,
summary_args = NULL,
transform = substitute(transform),
tidy_eval = tidy_eval,
packages = packages,
library = library,
Expand Down
23 changes: 23 additions & 0 deletions R/utils_assert.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,26 @@ assert_inc_warmup_fit <- function(inc_warmup, inc_warmup_fit) {
)
}
}

assert_transform <- function(transform) {
if (is.null(transform)) {
return()
}
if (!is.symbol(transform)) {
targets::tar_throw_validate("transform must be a symbol or NULL.")
}
name <- as.character(transform)
msg <- "transform must be a function in the pipeline environment."
if (!(name %in% names(targets::tar_option_get("envir")))) {
targets::tar_throw_validate(msg)
}
fun <- targets::tar_option_get("envir")[[name]]
if (!is.function(fun)) {
targets::tar_throw_validate(msg)
}
args <- names(formals(fun))
if (!all(c("data", "draws") %in% args)) {
msg <- "transform must have arguments \"data\" and \"draws\""
targets::tar_throw_validate(msg)
}
}
27 changes: 24 additions & 3 deletions R/utils_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ tar_stan_output <- function(
output_type,
summaries,
summary_args,
transform,
variables,
inc_warmup,
data,
Expand All @@ -29,7 +30,13 @@ tar_stan_output <- function(
summary_args = summary_args,
variables = variables
),
draws = tar_stan_output_draws(fit, variables, inc_warmup),
draws = tar_stan_output_draws(
fit = fit,
data = data,
variables = variables,
inc_warmup = inc_warmup,
transform = transform
),
diagnostics = tar_stan_output_diagnostics(fit, inc_warmup)
)
out <- tibble::as_tibble(out)
Expand Down Expand Up @@ -57,13 +64,27 @@ tar_stan_output_summary <- function(
eval(command)
}

tar_stan_output_draws <- function(fit, variables, inc_warmup) {
tar_stan_output_draws <- function(
fit,
data,
variables,
inc_warmup,
transform
) {
out <- if_any(
is.null(inc_warmup),
fit$draws(variables = variables),
fit$draws(variables = variables, inc_warmup = inc_warmup)
)
tibble::as_tibble(posterior::as_draws_df(out))
out <- tibble::as_tibble(posterior::as_draws_df(out))
if (!is.null(transform)) {
out <- do.call(
what = transform,
args = list(data = data, draws = out),
envir = targets::tar_option_get("envir")
)
}
out
}

tar_stan_output_diagnostics <- function(fit, inc_warmup) {
Expand Down
8 changes: 8 additions & 0 deletions man/tar_stan_gq_rep.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit f37eb66

Please sign in to comment.