Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better support for target_observation_weights #15

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
14 changes: 10 additions & 4 deletions R/constrain_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ constrain_draws.stanfit <- function(x, udraws, ...) {
}
}

new_draws <- posterior::as_draws_array(new_draws)
# new_draws <- posterior::as_draws_array(new_draws)
new_draws <- posterior::as_draws_matrix(new_draws)

new_draws
}


#' @export
constrain_draws.brmsfit <- function(x, draws, ...) {
# list with one element per posterior draw
Expand All @@ -93,9 +93,14 @@ constrain_draws.brmsfit <- function(x, draws, ...) {
lp__ <- log_prob_draws.stanfit(x, draws = udraws, ...)
draws <- rbind(draws, lp__ = lp__)

# TODO: what is difference between fnames_oi_old and fnames_oi
# bring draws into the right structure
# new_draws <- named_list(
# x@sim$fnames_oi_old,
# list(numeric(ndraws))
# )
new_draws <- named_list(
x@sim$fnames_oi_old,
x@sim$fnames_oi,
list(numeric(ndraws))
)

Expand All @@ -109,5 +114,6 @@ constrain_draws.brmsfit <- function(x, draws, ...) {
}
}

posterior::as_draws_array(new_draws)
# posterior::as_draws_array(new_draws)
posterior::as_draws_matrix(new_draws)
}
5 changes: 1 addition & 4 deletions R/example_iwmm_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ iwmm_examples <- function() {
}
parameters {
real mu;
real log_sigma;
}
transformed parameters {
real<lower=0> sigma = exp(log_sigma);
real<lower=0> sigma;
}
model {
target += normal_lpdf(x | mu, sigma);
Expand Down
2 changes: 0 additions & 2 deletions R/log_prob_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ log_prob_draws.stanfit <- function(fit, draws, ...) {
)
}



#' @export
log_prob_draws.brmsfit <- function(fit, draws, ...) {
# x <- update_misc_env(x, only_windows = TRUE)
Expand Down
88 changes: 66 additions & 22 deletions R/moment_match.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ moment_match.draws_rvars <- function(x,
#' @param draws_transformation_fun Optional argument, NULL by default. A
#' function that transforms draws before computing expectation. The function takes
#' arguments `draws`.
#' @param tdraws_fun Optional argument, NULL by default. A
#' function that transforms draws that are returned. The function takes
#' arguments `draws`.
#' @param is_method Which importance sampling method to use. Currently only `psis` is supported.
#' @param adaptation_method Which adaptation method to use. Currently only `iwmm` is supported.
#' @param k_threshold Threshold value for Pareto k values above which
Expand Down Expand Up @@ -211,6 +214,7 @@ moment_match.matrix <- function(x,
expectation_fun = NULL,
log_expectation_fun = FALSE,
draws_transformation_fun = NULL,
tdraws_fun = NULL,
is_method = "psis",
adaptation_method = "iwmm",
k_threshold = 0.5,
Expand Down Expand Up @@ -549,9 +553,16 @@ moment_match.matrix <- function(x,
}

if (!is.null(draws_transformation_fun)) {
draws <- draws_transformation_fun(draws)
if (!is.null(tdraws_fun)) {
tdraws <- tdraws_fun(draws)
} else {
tdraws <- draws_transformation_fun(draws)
}
unweighted_expectation <- expectation_fun(tdraws, ...)
} else {
tdraws <- NA
unweighted_expectation <- expectation_fun(draws, ...)
}
unweighted_expectation <- expectation_fun(draws, ...)

if (log_expectation_fun) {
expectation <- exp(matrixStats::colLogSumExps(
Expand All @@ -562,8 +573,13 @@ moment_match.matrix <- function(x,
expectation <- colSums(w * unweighted_expectation)
}

# TODO: How to distinguish naming of transformation we take as input
# and transformation we do in moment matching?
# Now draws is the draws transformed by moment matching
# tdraws is draws additionally transformed by draws_transformation_fun
adapted_draws <- list(
draws = draws,
tdraws = tdraws,
log_weights = lw,
expectation = expectation,
diagnostics = list(
Expand All @@ -587,7 +603,11 @@ moment_match.matrix <- function(x,
#' @param log_ratio_fun Log of the density ratio (target/proposal).
#' The function takes argument `draws`, which are the unconstrained
#' draws.
#' @param constrain_draws Logical specifying whether to return draws on the
#' @param target_observation_weights A vector of weights for observations for
#' defining the target distribution. A value 0 means dropping the observation,
#' a value 1 means including the observation similarly as in the current data,
#' and a value 2 means including the observation twice.
#' @param constrain Logical specifying whether to return draws on the
#' constrained space. Draws are also constrained for computing expectations. Default is TRUE.
#' @param ... Further arguments passed to `moment_match.matrix`.
#'
Expand All @@ -598,15 +618,22 @@ moment_match.matrix <- function(x,
moment_match.CmdStanFit <- function(x,
log_prob_target_fun = NULL,
log_ratio_fun = NULL,
constrain_draws = TRUE,
target_observation_weights = NULL,
constrain = TRUE,
...) {
if (!is.null(target_observation_weights) && (!is.null(log_prob_target_fun) || !is.null(log_ratio_fun))) {
stop("You must give only one of target_observation_weights, log_prob_target_fun, or log_ratio_fun.")
}

# TODO: actually implement target_observation_weights

# TODO: support expectation fun?
# and add tests for that

# transform the model parameters to unconstrained space
udraws <- x$unconstrain_draws(format = "draws_matrix")

if (constrain_draws) {
if (constrain) {
draws_transformation_fun <- function(draws, ...) {
return(constrain_draws(x, draws, ...))
}
Expand Down Expand Up @@ -649,7 +676,7 @@ moment_match.CmdStanFit <- function(x,
#' to FALSE. If set to TRUE, the expectation function must be
#' nonnegative (before taking the logarithm). Ignored if
#' `expectation_fun` is NULL.
#' @param constrain_draws Logical specifying whether to return draws on the
#' @param constrain Logical specifying whether to return draws on the
#' constrained space. Draws are also constrained for computing expectations. Default is TRUE.
#' @param ... Further arguments passed to `moment_match.matrix`.
#'
Expand All @@ -664,7 +691,7 @@ moment_match.stanfit <- function(x,
target_observation_weights = NULL,
expectation_fun = NULL,
log_expectation_fun = FALSE,
constrain_draws = TRUE,
constrain = TRUE,
...) {
if (!is.null(target_observation_weights) && (!is.null(log_prob_target_fun) || !is.null(log_ratio_fun))) {
stop("You must give only one of target_observation_weights, log_prob_target_fun, or log_ratio_fun.")
Expand Down Expand Up @@ -697,13 +724,14 @@ moment_match.stanfit <- function(x,
}
}


# transform the draws to unconstrained space
udraws <- unconstrain_draws(x, draws = draws, ...)

if (constrain_draws) {
if (constrain) {
draws_transformation_fun <- function(draws, ...) {
return(constrain_draws(x, draws, ...))
n_pars <- dim(draws)[2]
constrained_draws <- constrain_draws(x, draws, ...)
return(constrained_draws[, 1:n_pars])
}
} else {
draws_transformation_fun <- NULL
Expand Down Expand Up @@ -782,11 +810,12 @@ moment_match.brmsfit <- function(x,
}
)

function(draws, fit, extra_data, ...) {
fit <- .update_pars(x = fit, upars = draws)
ll <- brms::log_lik(fit, newdata = extra_data)
rowSums(ll)
}
# TODO: what is the point of this? was this left by accident?
# function(draws, fit, extra_data, ...) {
# fit <- .update_pars(x = fit, upars = draws)
# ll <- brms::log_lik(fit, newdata = extra_data)
# rowSums(ll)
# }

log_ratio_fun <- function(draws, fit, ...) {
fit <- .update_pars(x = fit, upars = draws)
Expand All @@ -795,9 +824,23 @@ moment_match.brmsfit <- function(x,
}
}


# transform the draws to unconstrained space
udraws <- unconstrain_draws.brmsfit(x, draws = draws, ...)
udraws <- unconstrain_draws(x, draws = draws, ...)

if (constrain) {
draws_transformation_fun <- function(draws, ...) {
n_pars <- dim(draws)[2]
constrained_draws <- constrain_draws(x, draws, ...)
return(constrained_draws[, 1:n_pars])
}
tdraws_fun <- function(draws, ...) {
constrained_draws <- constrain_draws(x, draws, ...)
return(constrained_draws)
}
} else {
draws_transformation_fun <- NULL
tdraws_fun <- NULL
}

out <- moment_match.matrix(
udraws,
Expand All @@ -806,18 +849,19 @@ moment_match.brmsfit <- function(x,
log_ratio_fun = log_ratio_fun,
expectation_fun = expectation_fun,
log_expectation_fun = log_expectation_fun,
draws_transformation_fun = draws_transformation_fun,
tdraws_fun = tdraws_fun,
fit = x,
...
)

x <- .update_pars(x = x, upars = out$draws)

if (constrain) {
out$draws <- posterior::as_draws(x)
udraws <- unconstrain_draws(x, draws = out$tdraws, ...)
x <- .update_pars(x = x, upars = udraws)
} else {
x <- .update_pars(x = x, upars = out$draws)
}

out$fit <- x


out
}
6 changes: 4 additions & 2 deletions R/unconstrain_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
if (is.null(dim(udraws))) {
dim(udraws) <- c(1, length(udraws))
}
out <- posterior::as_draws_matrix(t(udraws))

udraws <- t(udraws)
colnames(udraws) <- colnames(draws)[1:ncol(udraws)]

Check warning on line 23 in R/unconstrain_draws.R

View workflow job for this annotation

GitHub Actions / lint

file=R/unconstrain_draws.R,line=23,col=39,[seq_linter] 1:ncol(...) is likely to be wrong in the empty edge case. Use seq_len(ncol(...)) instead.
out <- posterior::as_draws_matrix(udraws)

out
}


#' @export
unconstrain_draws.brmsfit <- function(x, draws, ...) {
unconstrain_draws.stanfit(x$fit, draws = draws, ...)
Expand Down
82 changes: 82 additions & 0 deletions tests/testthat/test-moment-match-brmsfit-analytical.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,86 @@ if (brms_available) {
tolerance = 0.1
)
})

test_that("moment_match.brmsfit works with target_observation_weights", {
normal_model <- example_iwmm_model("normal_model")

bprior <- c(
prior("", class = "sigma"),
prior("", class = "Intercept")
)

fit <- brm(x ~ 1,
data = normal_model$data,
prior = bprior,
save_pars = save_pars(all = TRUE),
chains = 4,
iter = 1000,
refresh = 0,
seed = 1234
)

expectation_fun_first_moment <- function(draws, ...) {
draws
}

first_moment <- moment_match(
fit,
expectation_fun = expectation_fun_first_moment,
)

expect_equal(
first_moment$expectation,
posterior_summary(fit)[, 1],
tolerance = 0.001
)
# TODO: why lprior and lp__ are zeros?
expect_equal(
posterior_summary(fit, variable = c("b_Intercept", "sigma", "Intercept")),
posterior_summary(first_moment$fit, variable = c("b_Intercept", "sigma", "Intercept")),
tolerance = 0.001
)

# leave-one-out expectation

first_moment_loo <- moment_match(
fit,
target_observation_weights = append(rep(1, 9), 0),
expectation_fun = expectation_fun_first_moment,
)

loo_data <- normal_model$data
loo_data$x <- loo_data$x[-loo_data$N]
loo_data$N <- loo_data$N - 1

fit_loo <- brm(x ~ 1,
data = loo_data,
prior = bprior,
save_pars = save_pars(all = TRUE),
chains = 4,
iter = 1000,
refresh = 0,
seed = 1234
)

# TODO: why lprior and lp__ are zeros?
expect_equal(
first_moment_loo$expectation,
posterior_summary(fit_loo)[, 1],
tolerance = 0.001
)

expect_equal(
first_moment_loo$expectation[c("b_Intercept", "sigma", "Intercept")],
posterior_summary(fit_loo, variable = c("b_Intercept", "sigma", "Intercept"))[, 1],
tolerance = 0.1
)
# TODO: why lprior and lp__ are zeros?
# TODO: check other estimares than mean
expect_equal(
posterior_summary(fit_loo, variable = c("b_Intercept", "sigma", "Intercept"))[, 1],
posterior_summary(first_moment_loo$fit, variable = c("b_Intercept", "sigma", "Intercept"))[, 1],
tolerance = 0.1
)
})
}
Loading
Loading