Skip to content

Commit

Permalink
Merge pull request #238 from stan-dev/avehtari-patch-1
Browse files Browse the repository at this point in the history
Doc improvement in loo_subsample.R
  • Loading branch information
avehtari committed Jan 5, 2024
2 parents eb2e5ec + e3525c3 commit a69414b
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions R/loo_subsample.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#' Efficient approximate leave-one-out cross-validation (LOO) using subsampling
#' Efficient approximate leave-one-out cross-validation (LOO) using subsampling,
#' so that less costly and more approximate computation is made for all LOO-fold,
#' and more costly and accurate computations are made only for m<N LOO-folds.
#'
#' @param x A function. The **Methods (by class)** section, below, has detailed
#' descriptions of how to specify the inputs.
Expand Down Expand Up @@ -84,15 +86,17 @@ loo_subsample <- function(x, ...) {
#' @param estimator How should `elpd_loo`, `p_loo` and `looic` be estimated?
#' The default is `"diff_srs"`.
#' * `"diff_srs"`: uses the difference estimator with simple random sampling
#' (srs). `p_loo` is estimated using standard srs.
#' * `"hh"`: uses the Hansen-Hurwitz estimator with sampling proportional to
#' size, where `abs` of loo_approximation is used as size.
#' without replacement (srs). `p_loo` is estimated using standard srs.
#' (Magnusson et al., 2020)
#' * `"hh"`: uses the Hansen-Hurwitz estimator with sampling with replacement
#' proportional to size, where `abs` of loo_approximation is used as size.
#' (Magnusson et al., 2019)
#' * `"srs"`: uses simple random sampling and ordinary estimation.
#'
#' @param llgrad The gradient of the log-likelihood. This
#' is only used when `loo_approximation` is `"waic_grad"`,
#' `"waic_grad_marginal"`, or `"waic_hess"`. The default is `NULL`.
#' @param llhess The hessian of the log-likelihood. This is only used
#' @param llhess The Hessian of the log-likelihood. This is only used
#' with `loo_approximation = "waic_hess"`. The default is `NULL`.
#'
loo_subsample.function <-
Expand Down Expand Up @@ -814,7 +818,7 @@ pps_sample <- function(m, pis) {

## Constructor ---

#' Construct a `psis_loo_ss} object
#' Construct a `psis_loo_ss` object
#'
#' @noRd
#' @param x A `psis_loo` object.
Expand Down Expand Up @@ -1052,7 +1056,7 @@ update_m_i_in_pointwise <- function(pointwise, idxs, type = "replace") {

## Estimation ---

#' Estimate the elpd using the Hansen-Hurwitz estimator
#' Estimate the elpd using the Hansen-Hurwitz estimator (Magnusson et al., 2019)
#' @noRd
#' @param x A `psis_loo_ss` object.
#' @return A `psis_loo_ss` object.
Expand Down Expand Up @@ -1085,7 +1089,7 @@ loo_subsample_estimation_hh <- function(x) {
update_psis_loo_ss_estimates(x)
}

#' Update a `psis_loo_ss} object with generic estimates
#' Update a `psis_loo_ss` object with generic estimates
#'
#' @noRd
#' @details
Expand All @@ -1110,7 +1114,7 @@ update_psis_loo_ss_estimates <- function(x) {
x
}

#' Weighted Hansen-Hurwitz estimator
#' Weighted Hansen-Hurwitz estimator (Magnusson et al., 2019)
#' @noRd
#' @param z Normalized probabilities for the observation.
#' @param m_i The number of times obs i was selected.
Expand All @@ -1133,7 +1137,7 @@ whhest <- function(z, m_i, y, N) {
}


#' Estimate elpd using the difference estimator and srs wor
#' Estimate elpd using the difference estimator and SRS-WOR (Magnusson et al., 2020)
#' @noRd
#' @param x A `psis_loo_ss` object.
#' @return A `psis_loo_ss` object.
Expand All @@ -1153,7 +1157,7 @@ loo_subsample_estimation_diff_srs <- function(x) {
update_psis_loo_ss_estimates(x)
}

#' Difference estimation using SRS-WOR sampling
#' Difference estimation using SRS-WOR sampling (Magnusson et al., 2020)
#' @noRd
#' @param y_approx Approximated values of all observations.
#' @param y The values observed.
Expand All @@ -1175,8 +1179,14 @@ srs_diff_est <- function(y_approx, y, y_idx) {
t_hat_epsilon <- N * mean(y^2 - y_approx_m^2)

est_list <- list(m = length(y), N = N)
# eq (7)
est_list$y_hat <- t_pi_tilde + t_e
# eq (8)
est_list$v_y_hat <- N^2 * (1 - m / N) * var(e_i) / m
# eq (9) first row second `+` should be `-`
# Supplementary material eq (6) has this correct
# Here the variance is for sum, while in the paper the variance is for mean
# which explains the proporional difference of 1/n
est_list$hat_v_y <- (t_pi2_tilde + t_hat_epsilon) - # a (has been checked)
(1/N) * (t_e^2 - est_list$v_y_hat + 2 * t_pi_tilde * est_list$y_hat - t_pi_tilde^2) # b
est_list
Expand Down

0 comments on commit a69414b

Please sign in to comment.