Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmdstanr and posterior updates #10

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ jobs:
cache-version: 1
extra-packages: |
stan-dev/cmdstanr
stan-dev/posterior
rcmdcheck
checkmate
jsonlite
posterior
processx
R6
BH
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ jobs:
cache-version: 1
extra-packages: |
stan-dev/cmdstanr
stan-dev/posterior
rcmdcheck
checkmate
jsonlite
posterior
processx
R6
BH
Expand Down
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ RoxygenNote: 7.2.3
Depends:
R (>= 3.1.2)
Imports:
abind,
checkmate,
matrixStats (>= 0.52),
posterior,
posterior (>= 1.5.0),
stats
Suggests:
bayesplot,
Expand All @@ -26,6 +25,7 @@ Suggests:
rstan,
testthat
Enhances:
brms,
cmdstanr
VignetteBuilder: knitr
Config/testthat/parallel: true
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(constrain_draws,stanfit)
S3method(moment_match,CmdStanFit)
S3method(moment_match,brmsfit)
S3method(moment_match,draws_array)
S3method(moment_match,draws_df)
S3method(moment_match,draws_list)
Expand Down
176 changes: 176 additions & 0 deletions R/brmsfit_functions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#' Generic importance weighted moment matching algorithm for `brmsfit` objects.
#' See additional arguments from `moment_match.matrix`
#'
#' @param x A fitted `brmsfit` object.
#' @param log_prob_target_fun Log density of the target. The function
#' takes argument `draws`, which are the unconstrained draws.
#' Can also take the argument `fit` which is the stan model fit.
#' @param log_ratio_fun Log of the density ratio (target/proposal).
#' The function takes argument `draws`, which are the unconstrained
#' draws. Can also take the argument `fit` which is the stan model fit.
#' @param target_observation_weights A vector of weights for observations for
#' defining the target distribution. A value 0 means dropping the observation,
#' a value 1 means including the observation similarly as in the current data,
#' and a value 2 means including the observation twice.
#' @param expectation_fun Optional argument, NULL by default. A
#' function whose expectation is being computed. The function takes
#' arguments `draws`.
#' @param log_expectation_fun Logical indicating whether the
#' expectation_fun returns its values as logarithms or not. Defaults
#' to FALSE. If set to TRUE, the expectation function must be
#' nonnegative (before taking the logarithm). Ignored if
#' `expectation_fun` is NULL.
#' @param constrain Logical specifying whether to return draws on the
#' constrained space? Default is TRUE.
#' @param ... Further arguments passed to `moment_match.matrix`.
#'
#' @return Returns a list with 3 elements: transformed draws, updated
#' importance weights, and the pareto k diagnostic value. If expectation_fun
#' is given, also returns the expectation.
#'
#' @export
moment_match.brmsfit <- function(x,
log_prob_target_fun = NULL,
log_ratio_fun = NULL,
target_observation_weights = NULL,
expectation_fun = NULL,
log_expectation_fun = FALSE,
constrain = TRUE,
...) {
if (!is.null(target_observation_weights) && (!is.null(log_prob_target_fun) || !is.null(log_ratio_fun))) {
stop("You must give only one of target_observation_weights, log_prob_target_fun, or log_ratio_fun.")
}

# ensure draws are in matrix form
draws <- posterior::as_draws_matrix(x)
# draws <- as.matrix(draws)

if (!is.null(target_observation_weights)) {
out <- tryCatch(log_lik(x),

Check warning on line 49 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=49,col=21,[object_usage_linter] no visible global function definition for 'log_lik'
error = function(cond) {

Check warning on line 50 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=50,col=20,[indentation_linter] Indentation should be 6 spaces but is 20 spaces.
message(cond)
message("\nYour brmsfit does not include a parameter called log_lik.")
message("This should not happen. Perhaps you are using an unsupported observation model?")
return(NA)
}
)

function(draws, fit, extra_data, ...) {
fit <- brms:::.update_pars(x = fit, upars = draws)
ll <- log_lik(fit, newdata = extra_data)

Check warning on line 60 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=60,col=13,[object_usage_linter] no visible global function definition for 'log_lik'
rowSums(ll)
}

log_ratio_fun <- function(draws, fit, ...) {
fit <- brms:::.update_pars(x = fit, upars = draws)
ll <- log_lik(fit)

Check warning on line 66 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=66,col=13,[object_usage_linter] no visible global function definition for 'log_lik'
colSums(t(drop(ll)) * (target_observation_weights - 1))
}
}


# transform the draws to unconstrained space
udraws <- unconstrain_draws.brmsfit(x, draws = draws, ...)

out <- moment_match.matrix(
# as.matrix(udraws),
udraws,
log_prob_prop_fun = log_prob_draws.brmsfit,
log_prob_target_fun = log_prob_target_fun,
log_ratio_fun = log_ratio_fun,
expectation_fun = expectation_fun,
log_expectation_fun = log_expectation_fun,
fit = x,
...
)

# TODO: this does not work for some reason
# x <- brms:::.update_pars(x = x, upars = out$draws)
# x <- update_pars_brmsfit(x = x, draws = out$draws)

if (constrain) {
out$draws <- constrain_draws.stanfit(x$fit, out$draws, ...)
}

list(adapted_importance_sampling = out,
brmsfit_object = x)
}


log_prob_draws.brmsfit <- function(fit, draws, ...) {
# x <- update_misc_env(x, only_windows = TRUE)
log_prob_draws.stanfit(fit$fit, draws = draws, ...)
}

unconstrain_draws.brmsfit <- function(x, draws, ...) {
unconstrain_draws.stanfit(x$fit, draws = draws, ...)
}

constrain_draws.brmsfit <- function(x, udraws, ...) {
out <- rstan::constrain_pars(udraws, object = x$fit)
out[x$exclude] <- NULL
out
}

# # transform parameters to the constraint space
update_pars_brmsfit <- function(x, draws, ...) {
# list with one element per posterior draw
pars <- apply(draws, 1, constrain_draws.brmsfit, x = x)
# select required parameters only
pars <- lapply(pars, "[", x$fit@sim$pars_oi_old)
# transform draws
ndraws <- length(pars)
pars <- unlist(pars)
npars <- length(pars) / ndraws
dim(pars) <- c(npars, ndraws)
# add dummy 'lp__' draws
pars <- rbind(pars, rep(0, ndraws))
# bring draws into the right structure
new_draws <- named_list(x$fit@sim$fnames_oi_old, list(numeric(ndraws)))
if (length(new_draws) != nrow(pars)) {
stop2("Updating parameters in `update_pars_brmsfit' failed.")

Check warning on line 131 in R/brmsfit_functions.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brmsfit_functions.R,line=131,col=5,[object_usage_linter] no visible global function definition for 'stop2'
}
for (i in seq_len(npars)) {
new_draws[[i]] <- pars[i, ]
}
# create new sim object to overwrite x$fit@sim
x$fit@sim <- list(
samples = list(new_draws),
iter = ndraws,
thin = 1,
warmup = 0,
chains = 1,
n_save = ndraws,
warmup2 = 0,
permutation = list(seq_len(ndraws)),
pars_oi = x$fit@sim$pars_oi_old,
dims_oi = x$fit@sim$dims_oi_old,
fnames_oi = x$fit@sim$fnames_oi_old,
n_flatnames = length(x$fit@sim$fnames_oi_old)
)
x$fit@stan_args <- list(
list(chain_id = 1, iter = ndraws, thin = 1, warmup = 0)
)
brms::rename_pars(x)
}

# update .MISC environment of the stanfit object
# allows to call log_prob and other C++ using methods
# on objects not created in the current R session
# or objects created via another backend
# update_misc_env <- function(x, recompile = FALSE, only_windows = FALSE) {
# stopifnot(is.brmsfit(x))
# recompile <- as_one_logical(recompile)
# only_windows <- as_one_logical(only_windows)
# if (recompile || !has_rstan_model(x)) {
# x <- add_rstan_model(x, overwrite = TRUE)
# } else if (os_is_windows() || !only_windows) {
# # TODO: detect when updating .MISC is not required
# # TODO: find a more efficient way to update .MISC
# old_backend <- x$backend
# x$backend <- "rstan"
# x$fit@.MISC <- suppressMessages(brm(fit = x, chains = 0))$fit@.MISC
# x$backend <- old_backend
# }
# x
# }
1 change: 1 addition & 0 deletions R/constrain_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ constrain_draws <- function(x, ...) {
UseMethod("constrain_draws")
}

##' @export
constrain_draws.CmdStanFit <- function(x, udraws, ...) {
# list with one element per posterior draw
draws <- apply(udraws, 1, x$constrain_variables)
Expand Down
45 changes: 23 additions & 22 deletions R/moment_match.R
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,14 @@
target density is equal to your proposal density.")
}

pareto_smoothed_w <- posterior::pareto_smooth(exp(lw - matrixStats::logSumExp(lw)),
tail = "right", extra_diags = TRUE, r_eff = 1
pareto_smoothed_w <- posterior::pareto_smooth(
lw - matrixStats::logSumExp(lw), are_log_weights = TRUE,
tail = "right", extra_diags = TRUE, r_eff = 1,
return_k = TRUE,
verbose = FALSE
)
k <- pareto_smoothed_w$diagnostics$khat
lw <- log(as.vector(pareto_smoothed_w$x))
lw <- as.vector(pareto_smoothed_w$x)

if (any(is.infinite(k))) {
stop("Something went wrong, and encountered infinite Pareto k values..")
Expand Down Expand Up @@ -339,8 +342,10 @@
lwf <- compute_lwf(draws, lw, expectation_fun, log_expectation_fun, draws_transformation_fun, ...)

pareto_smoothed_wf <- apply(lwf, 2, function(x) {
posterior::pareto_smooth(exp(x),
tail = "right", extra_diags = TRUE, r_eff = 1
posterior::pareto_smooth(
x, are_log_weights = TRUE,
tail = "right", extra_diags = TRUE, r_eff = 1,
return_k = TRUE, verbose = FALSE
)
})
pareto_smoothed_wf <- do.call(mapply, c(cbind, pareto_smoothed_wf))
Expand Down Expand Up @@ -369,7 +374,7 @@
that return a matrix. As a workaround, you can wrap your function
call using apply.")
}
lwf <- log(as.vector(pareto_smoothed_wf$x))
lwf <- as.vector(pareto_smoothed_wf$x)

if (is.null(log_prob_target_fun) && is.null(log_ratio_fun)) {
update_properties <- list(
Expand Down Expand Up @@ -598,26 +603,22 @@
# and add tests for that

# transform the model parameters to unconstrained space
udraws <- x$unconstrain_draws()
udraws <- aperm(
abind::abind(
lapply(udraws, function(x) abind::abind(x, along = 2)),
along = 3
),
perm = c(2, 3, 1)
)
udraws <- x$unconstrain_draws(format = "draws_matrix")

Check warning on line 606 in R/moment_match.R

View workflow job for this annotation

GitHub Actions / lint

file=R/moment_match.R,line=606,col=4,[indentation_linter] Indentation should be 2 spaces but is 4 spaces.

udraws <- matrix(
udraws,
nrow = dim(udraws)[1] * dim(udraws)[2],
ncol = dim(udraws)[3]
)
if (constrain_draws) {
draws_transformation_fun <- function(draws, ...) {
return(constrain_draws(x, draws, ...))
}
} else {
draws_transformation_fun <- NULL
}

out <- moment_match.matrix(
out <- moment_match(
x = udraws,
log_prob_prop_fun = log_prob_draws.CmdStanFit,
log_prob_target_fun = log_prob_target_fun,
log_ratio_fun = log_ratio_fun,
draws_transformation_fun = draws_transformation_fun,
fit = x,
...
)
Expand Down Expand Up @@ -697,11 +698,11 @@


# transform the draws to unconstrained space
udraws <- unconstrain_draws.stanfit(x, draws = draws, ...)
udraws <- unconstrain_draws(x, draws = draws, ...)

if (constrain_draws) {
draws_transformation_fun <- function(draws, ...) {
return(constrain_draws.stanfit(x, draws, ...))
return(constrain_draws(x, draws, ...))
}
} else {
draws_transformation_fun <- NULL
Expand Down
11 changes: 7 additions & 4 deletions R/update_quantities.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@ update_quantities <- function(draws, orig_log_prob_prop,
)
}

pareto_smoothed_w_new <- posterior::pareto_smooth(exp(lw_new - matrixStats::logSumExp(lw_new)),
tail = "right", r_eff = 1,
return_k = TRUE
pareto_smoothed_w_new <- posterior::pareto_smooth(
lw_new - matrixStats::logSumExp(lw_new),
are_log_weights = TRUE,
tail = "right", r_eff = 1, return_k = TRUE,
verbose = FALSE
)
k <- pareto_smoothed_w_new$diagnostics$khat
lw <- log(as.vector(pareto_smoothed_w_new$x))
lw <- as.vector(pareto_smoothed_w_new$x)

# normalize log weights
lw <- lw - matrixStats::logSumExp(lw)

Expand Down
Loading
Loading