Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid R cmd check NOTEs about some internal functions #240

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ S3method(as.psis_loo,psis_loo)
S3method(as.psis_loo,psis_loo_ss)
S3method(as.psis_loo_ss,psis_loo)
S3method(as.psis_loo_ss,psis_loo_ss)
S3method(compute_point_estimate,default)
S3method(compute_point_estimate,matrix)
S3method(crps,matrix)
S3method(crps,numeric)
S3method(dim,importance_sampling)
Expand All @@ -33,12 +35,15 @@ S3method(loo_approximate_posterior,"function")
S3method(loo_approximate_posterior,array)
S3method(loo_approximate_posterior,matrix)
S3method(loo_compare,default)
S3method(loo_compare,psis_loo_ss_list)
S3method(loo_crps,matrix)
S3method(loo_model_weights,default)
S3method(loo_moment_match,default)
S3method(loo_predictive_metric,matrix)
S3method(loo_scrps,matrix)
S3method(loo_subsample,"function")
S3method(n_draws,default)
S3method(n_draws,matrix)
S3method(nobs,psis_loo_ss)
S3method(plot,loo)
S3method(plot,psis)
Expand Down Expand Up @@ -78,6 +83,9 @@ S3method(scrps,numeric)
S3method(sis,array)
S3method(sis,default)
S3method(sis,matrix)
S3method(thin_draws,default)
S3method(thin_draws,matrix)
S3method(thin_draws,numeric)
S3method(tis,array)
S3method(tis,default)
S3method(tis,matrix)
Expand All @@ -86,9 +94,6 @@ S3method(waic,"function")
S3method(waic,array)
S3method(waic,matrix)
S3method(weights,importance_sampling)
export(.compute_point_estimate)
export(.ndraws)
export(.thin_draws)
export(E_loo)
export(compare)
export(crps)
Expand Down
1 change: 1 addition & 0 deletions R/loo_compare.psis_loo_ss_list.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#' @param ... Currently ignored.
#' @return A `compare.loo_ss` object.
#' @author Mans Magnusson
#' @export
loo_compare.psis_loo_ss_list <- function(x, ...) {

checkmate::assert_list(x, any.missing = FALSE, min.len = 1)
Expand Down
84 changes: 41 additions & 43 deletions R/loo_subsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ loo_subsample.function <-
cores <- loo_cores(cores)

checkmate::assert_choice(loo_approximation, choices = loo_approximation_choices(), null.ok = FALSE)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = .ndraws(draws), null.ok = TRUE)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = n_draws(draws), null.ok = TRUE)
checkmate::assert_choice(estimator, choices = estimator_choices())

.llgrad <- .llhess <- NULL
Expand Down Expand Up @@ -234,7 +234,7 @@ loo_subsample.function <-
.llgrad = .llgrad,
.llhess = .llhess,
data_dim = dim(data),
ndraws = .ndraws(draws))
ndraws = n_draws(draws))
loo_ss
}

Expand Down Expand Up @@ -379,13 +379,13 @@ update.psis_loo_ss <- function(object, ...,

if (length(observations) == 1) {
# Add new samples pointwise and diagnostic
object <- rbind.psis_loo_ss(object, x = loo_obj)
object <- rbind_psis_loo_ss(object, x = loo_obj)

# Update m_i for current pointwise (diagnostic stay the same)
object$pointwise <- update_m_i_in_pointwise(object$pointwise, cidxs$add, type = "add")
} else {
# Add new samples pointwise and diagnostic
object <- rbind.psis_loo_ss(object, loo_obj)
object <- rbind_psis_loo_ss(object, loo_obj)

# Replace m_i current pointwise and diagnostics
object$pointwise <- update_m_i_in_pointwise(object$pointwise, cidxs$add, type = "replace")
Expand Down Expand Up @@ -537,20 +537,20 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
if (loo_approximation == "none") return(rep(1L,N))

if (loo_approximation %in% c("tis", "sis")) {
draws <- .thin_draws(draws, loo_approximation_draws)
draws <- thin_draws(draws, loo_approximation_draws)
is_values <- suppressWarnings(loo.function(.llfun, data = data, draws = draws, is_method = loo_approximation))
return(is_values$pointwise[, "elpd_loo"])
}

if (loo_approximation == "waic") {
draws <- .thin_draws(draws, loo_approximation_draws)
draws <- thin_draws(draws, loo_approximation_draws)
waic_full_obj <- waic.function(.llfun, data = data, draws = draws)
return(waic_full_obj$pointwise[,"elpd_waic"])
}

# Compute the lpd or log p(y_i|y_{-i})
if (loo_approximation == "lpd") {
draws <- .thin_draws(draws, loo_approximation_draws)
draws <- thin_draws(draws, loo_approximation_draws)
lpds <- compute_lpds(N, data, draws, .llfun, cores)
return(lpds) # Use only the lpd
}
Expand All @@ -561,8 +561,8 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
loo_approximation == "waic_grad_marginal" |
loo_approximation == "waic_hess") {

draws <- .thin_draws(draws, loo_approximation_draws)
point_est <- .compute_point_estimate(draws)
draws <- thin_draws(draws, loo_approximation_draws)
point_est <- compute_point_estimate(draws)
lpds <- compute_lpds(N, data, point_est, .llfun, cores)
if (loo_approximation == "plpd") return(lpds) # Use only the lpd
}
Expand All @@ -572,7 +572,7 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
loo_approximation == "waic_hess") {
checkmate::assert_true(!is.null(.llgrad))

point_est <- .compute_point_estimate(draws)
point_est <- compute_point_estimate(draws)
# Compute the lpds
lpds <- compute_lpds(N, data, point_est, .llfun, cores)

Expand Down Expand Up @@ -620,79 +620,77 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation

#' Compute a point estimate from a draws object
#'
#' @details This is a generic function to thin draws from arbitrary draws
#' @noRd
#' @details This is a generic function to compute point estimates from draws
#' objects. The function is internal and should only be used by developers to
#' enable [loo_subsample()] for arbitrary draws objects.
#'
#' @param draws A draws object with draws from the posterior.
#' @return A 1 by P matrix with point estimates from a draws object.
#' @keywords internal
#' @export
.compute_point_estimate <- function(draws) {
UseMethod(".compute_point_estimate")
compute_point_estimate <- function(draws) {
UseMethod("compute_point_estimate")
}

.compute_point_estimate.matrix <- function(draws) {
#' @export
compute_point_estimate.matrix <- function(draws) {
t(as.matrix(colMeans(draws)))
}

.compute_point_estimate.default <- function(draws) {
stop(".compute_point_estimate() has not been implemented for objects of class '", class(draws), "'")
#' @export
compute_point_estimate.default <- function(draws) {
stop("compute_point_estimate() has not been implemented for objects of class '", class(draws), "'")
}

#' Thin a draws object
#'
#' @noRd
#' @details This is a generic function to thin draws from arbitrary draws
#' objects. The function is internal and should only be used by developers to
#' enable [loo_subsample()] for arbitrary draws objects.
#'
#' @param draws A draws object with posterior draws.
#' @param loo_approximation_draws The number of posterior draws to return (ie after thinning).
#' @keywords internal
#' @export
#' @return A thinned draws object.
.thin_draws <- function(draws, loo_approximation_draws) {
UseMethod(".thin_draws")
thin_draws <- function(draws, loo_approximation_draws) {
UseMethod("thin_draws")
}

.thin_draws.matrix <- function(draws, loo_approximation_draws) {
#' @export
thin_draws.matrix <- function(draws, loo_approximation_draws) {
if (is.null(loo_approximation_draws)) return(draws)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = .ndraws(draws), null.ok = TRUE)
S <- .ndraws(draws)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = n_draws(draws), null.ok = TRUE)
S <- n_draws(draws)
idx <- 1:loo_approximation_draws * S %/% loo_approximation_draws
draws <- draws[idx, , drop = FALSE]
draws
}

.thin_draws.numeric <- function(draws, loo_approximation_draws) {
.thin_draws.matrix(as.matrix(draws), loo_approximation_draws)
#' @export
thin_draws.numeric <- function(draws, loo_approximation_draws) {
thin_draws.matrix(as.matrix(draws), loo_approximation_draws)
}

.thin_draws.default <- function(draws, loo_approximation_draws) {
stop(".thin_draws() has not been implemented for objects of class '", class(draws), "'")
#' @export
thin_draws.default <- function(draws, loo_approximation_draws) {
stop("thin_draws() has not been implemented for objects of class '", class(draws), "'")
}


#' The number of posterior draws in a draws object.
#'
#' @noRd
#' @details This is a generic function to return the total number of draws from
#' an arbitrary draws objects. The function is internal and should only be
#' used by developers to enable [loo_subsample()] for arbitrary draws objects.
#'
#' @param x A draws object with posterior draws.
#' @return An integer with the number of draws.
#' @keywords internal
#' @export
.ndraws <- function(x) {
UseMethod(".ndraws")
n_draws <- function(x) {
UseMethod("n_draws")
}

.ndraws.matrix <- function(x) {
#' @export
n_draws.matrix <- function(x) {
nrow(x)
}

.ndraws.default <- function(x) {
stop(".ndraws() has not been implemented for objects of class '", class(x), "'")
#' @export
n_draws.default <- function(x) {
stop("n_draws() has not been implemented for objects of class '", class(x), "'")
}

## Subsampling -----
Expand Down Expand Up @@ -969,7 +967,7 @@ add_subsampling_vars_to_pointwise <- function(pointwise, idxs, elpd_loo_approx)
#' @param object A `psis_loo_ss` object.
#' @param x A `psis_loo` object.
#' @return An updated `psis_loo_ss` object.
rbind.psis_loo_ss <- function(object, x) {
rbind_psis_loo_ss <- function(object, x) {
checkmate::assert_class(object, "psis_loo_ss")
if (is.null(x)) return(object) # Fallback
checkmate::assert_class(x, "psis_loo")
Expand Down
23 changes: 0 additions & 23 deletions man/dot-compute_point_estimate.Rd

This file was deleted.

23 changes: 0 additions & 23 deletions man/dot-ndraws.Rd

This file was deleted.

25 changes: 0 additions & 25 deletions man/dot-thin_draws.Rd

This file was deleted.

Loading