Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ref_interval option to mcmc_rank_* functions. #248

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions R/mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
#' of rank-normalized MCMC samples. Defaults to `20`.
#' @param ref_line For the rank plots, whether to draw a horizontal line at the
#' average number of ranks per bin. Defaults to `FALSE`.
#' @param ref_interval For the rank plots, whether to draw a reference
#' uncertainty interval based on the expected distribution of the rank histogram
#' bins. Defaults to `FALSE`.
#' @param interval_args If `ref_interval = TRUE`, optional arguments controlling
#' the appearance of the reference interval. This must be a list with elements
#' named `width` (the probability mass covered by the interval, between 0 and
#' 1) and `alpha` (the transparency of the shaded region, between 0 and 1).
#' The default is a `95\%` uncertainty interval plotted with an alpha value of
#' `0.2`.
#' @export
mcmc_rank_overlay <- function(x,
pars = character(),
Expand All @@ -269,7 +276,9 @@ mcmc_rank_overlay <- function(x,
facet_args = list(),
...,
n_bins = 20,
ref_line = FALSE) {
ref_line = FALSE,
ref_interval = FALSE,
interval_args = list(width = 0.95, alpha = 0.2)) {
check_ignored_arguments(...)
data <- mcmc_trace_data(
x,
Expand All @@ -278,6 +287,14 @@ mcmc_rank_overlay <- function(x,
transformations = transformations
)

# mcmc_rank plots make no sense if there aren't multiple chains
# a rank plot of 1 chain is perfectly uniform by construction, and
# has no power as a diagnostic.
if (!(unique(data$n_chains) > 1)){
STOP_need_multiple_chains()
}

n_iter <- unique(data$n_iterations)
n_chains <- unique(data$n_chains)
n_param <- unique(data$n_parameters)

Expand Down Expand Up @@ -316,6 +333,12 @@ mcmc_rank_overlay <- function(x,
} else {
NULL
}

interval_call <- if (ref_interval) {
rank_polygon_geom(n_iter, n_chains, n_bins, interval_args)
} else {
NULL
}

facet_call <- NULL
if (n_param > 1) {
Expand All @@ -329,6 +352,7 @@ mcmc_rank_overlay <- function(x,
geom_step() +
layer_ref_line +
facet_call +
interval_call +
scale_color +
ylim(c(0, NA)) +
bayesplot_theme_get() +
Expand All @@ -345,7 +369,9 @@ mcmc_rank_hist <- function(x,
...,
facet_args = list(),
n_bins = 20,
ref_line = FALSE) {
ref_line = FALSE,
ref_interval = FALSE,
interval_args = list(width = 0.95, alpha = 0.2)) {
check_ignored_arguments(...)
data <- mcmc_trace_data(
x,
Expand All @@ -354,6 +380,10 @@ mcmc_rank_hist <- function(x,
transformations = transformations
)

if (!(unique(data$n_chains) > 1)){
STOP_need_multiple_chains()
}

n_iter <- unique(data$n_iterations)
n_chains <- unique(data$n_chains)
n_param <- unique(data$n_parameters)
Expand Down Expand Up @@ -396,6 +426,11 @@ mcmc_rank_hist <- function(x,
}

facet_call <- do.call(facet_f, facet_args)
interval_call <- if (ref_interval) {
rank_polygon_geom(n_iter, n_chains, n_bins, interval_args)
} else {
NULL
}

ggplot(data) +
aes_(x = ~ value_rank) +
Expand All @@ -409,6 +444,7 @@ mcmc_rank_hist <- function(x,
layer_ref_line +
geom_blank(data = data_boundaries) +
facet_call +
interval_call +
force_x_axis_in_facets() +
dont_expand_y_axis(c(0.005, 0)) +
bayesplot_theme_get() +
Expand Down Expand Up @@ -681,3 +717,30 @@ divergence_rug <- function(np, np_style, n_iter, n_chain) {
alpha = np_style$alpha[["div"]]
)
}

# Build the shaded reference band drawn by the rank plots.
#
# The band covers the central `interval_args$width` probability mass of a
# Binomial(size = n_iter, prob = 1 / n_bins) count and is drawn as a rectangle
# running from x = 0 to x = n_iter * n_chains.
#
# @param n_iter Binomial size used for the band limits; also scales the x
#   extent of the band.
# @param n_chains Multiplied by `n_iter` to get the right-hand x limit.
# @param n_bins Number of histogram bins; each bin has probability 1 / n_bins.
# @param interval_args List with elements `width` and `alpha`. Elements that
#   are absent are filled in from the defaults (`width = 0.95`, `alpha = 0.2`)
#   before validation, so a partial list such as `list(width = 0.8)` works.
# @return A `geom_polygon()` layer that carries its own data and does not
#   inherit the aesthetics of the plot it is added to.
rank_polygon_geom <- function(n_iter, n_chains, n_bins, interval_args) {
  # Fill in whichever of width/alpha the caller left out. Non-list input is
  # passed through untouched so that the validator reports the problem.
  defaults <- list(width = 0.95, alpha = 0.2)
  if (is.list(interval_args)) {
    absent <- setdiff(names(defaults), names(interval_args))
    interval_args[absent] <- defaults[absent]
  }
  validate_interval_args(interval_args)

  width <- interval_args[["width"]]

  # Lower and upper binomial quantiles bounding the central `width` mass.
  polygon_y_vals <- qbinom(
    c((1 - width) / 2, (1 + width) / 2),
    size = n_iter,
    prob = n_bins^(-1)
  )

  # Rectangle corners in drawing order:
  # (0, lower), (0, upper), (max, upper), (max, lower).
  polygon_df <- data.frame(
    x = rep(c(0, n_iter * n_chains), each = 2),
    y = c(polygon_y_vals, rev(polygon_y_vals))
  )

  # Formula-based aes_() matches the rest of this file and avoids bare `x`/`y`
  # symbols that R CMD check reports as undefined global variables.
  geom_polygon(
    mapping = aes_(x = ~x, y = ~y),
    data = polygon_df,
    inherit.aes = FALSE,
    alpha = interval_args[["alpha"]]
  )
}

# Check the `interval_args` list accepted by the mcmc_rank_* functions.
#
# @param interval_args Must be a list with exactly two elements, named `width`
#   and `alpha`, each a single non-missing number between 0 and 1 (inclusive).
# @return `NULL`, invisibly. Called for its side effect: the first requirement
#   that is not met raises an error through `stopifnot()`, so the message names
#   the failing expression followed by "is not TRUE".
validate_interval_args <- function(interval_args) {
  stopifnot(is.list(interval_args))

  # No unknown names (catches typos such as `with` instead of `width`) ...
  stopifnot(all(names(interval_args) %in% c("width", "alpha")))
  # ... and both required names present exactly once. Without this, a missing
  # element is NULL and every later range check would pass vacuously.
  stopifnot(all(c("width", "alpha") %in% names(interval_args)))
  stopifnot(length(interval_args) == 2L)

  width <- interval_args[["width"]]
  alpha <- interval_args[["alpha"]]

  # stopifnot() evaluates its arguments in order and stops at the first
  # failure, so the range comparisons only ever see a single valid number.
  stopifnot(
    is.numeric(width), length(width) == 1L, !is.na(width),
    width >= 0, width <= 1
  )
  stopifnot(
    is.numeric(alpha), length(alpha) == 1L, !is.na(alpha),
    alpha >= 0, alpha <= 1
  )

  invisible(NULL)
}
17 changes: 15 additions & 2 deletions man/MCMC-traces.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 61 additions & 0 deletions tests/testthat/test-mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,30 @@ test_that("mcmc_trace_highlight throws error if highlight > number of chains", {
expect_error(mcmc_trace_highlight(arr, pars = "sigma", highlight = 7), "'highlight' is 7")
})

test_that("mcmc_rank_hist returns a ggplot object", {
  # Build one plot per supported input container, then check each result.
  # (`arr`, `dframe_multiple_chains` and `chainlist` are fixtures defined
  # outside this view -- presumably in a test helper.)
  from_arr <- mcmc_rank_hist(x = arr, pars = "beta[1]", regex_pars = "x\\:")
  from_dframe <- mcmc_rank_hist(x = dframe_multiple_chains)
  from_chainlist <- mcmc_rank_hist(x = chainlist)

  expect_gg(from_arr)
  expect_gg(from_dframe)
  expect_gg(from_chainlist)
})

test_that("mcmc_rank_overlay returns a ggplot object", {
  # Build one plot per supported input container, then check each result.
  # (`arr`, `dframe_multiple_chains` and `chainlist` are fixtures defined
  # outside this view -- presumably in a test helper.)
  from_arr <- mcmc_rank_overlay(x = arr, pars = "beta[1]", regex_pars = "x\\:")
  from_dframe <- mcmc_rank_overlay(x = dframe_multiple_chains)
  from_chainlist <- mcmc_rank_overlay(x = chainlist)

  expect_gg(from_arr)
  expect_gg(from_dframe)
  expect_gg(from_chainlist)
})

test_that("mcmc_rank_hist errors if there is only 1 chain", {
  # Each fixture below is expected to carry a single chain, so the
  # multiple-chains guard in mcmc_rank_hist() must fire for all of them.
  expect_error(
    mcmc_rank_hist(x = mat),
    regexp = "requires multiple"
  )
  expect_error(
    mcmc_rank_hist(x = dframe),
    regexp = "requires multiple chains"
  )
  expect_error(
    mcmc_rank_hist(x = arr1chain),
    regexp = "requires multiple chains"
  )
})

test_that("mcmc_rank_overlay errors if there is only 1 chain", {
  # Each fixture below is expected to carry a single chain, so the
  # multiple-chains guard in mcmc_rank_overlay() must fire for all of them.
  expect_error(
    mcmc_rank_overlay(x = mat),
    regexp = "requires multiple"
  )
  expect_error(
    mcmc_rank_overlay(x = dframe),
    regexp = "requires multiple chains"
  )
  expect_error(
    mcmc_rank_overlay(x = arr1chain),
    regexp = "requires multiple chains"
  )
})

# options -----------------------------------------------------------------
test_that("mcmc_trace options work", {
expect_gg(g1 <- mcmc_trace(arr, regex_pars = "beta", window = c(5, 10)))
Expand All @@ -47,6 +71,43 @@ test_that("mcmc_trace options work", {
expect_error(mcmc_trace(arr, n_warmup = 50, iter1 = 20))
})

test_that("mcmc_rank_hist options work", {
  # Reference interval with its default width and alpha.
  expect_gg(mcmc_rank_hist(x = arr, regex_pars = "beta", ref_interval = TRUE))

  # Reference line plus a customised interval on a non-default bin count.
  custom_interval <- list(width = 0.8, alpha = 0.1)
  with_all_options <- mcmc_rank_hist(
    x = arr,
    regex_pars = "beta",
    n_bins = 15,
    ref_line = TRUE,
    ref_interval = TRUE,
    interval_args = custom_interval
  )
  expect_gg(with_all_options)
})

test_that("mcmc_rank_overlay options work", {
  # Reference interval with its default width and alpha.
  expect_gg(mcmc_rank_overlay(x = arr, regex_pars = "beta", ref_interval = TRUE))

  # Reference line plus a customised interval on a non-default bin count.
  custom_interval <- list(width = 0.8, alpha = 0.1)
  with_all_options <- mcmc_rank_overlay(
    x = arr,
    regex_pars = "beta",
    n_bins = 15,
    ref_line = TRUE,
    ref_interval = TRUE,
    interval_args = custom_interval
  )
  expect_gg(with_all_options)
})

test_that("mcmc_rank_* interval_args get validated", {
  # `with` is a deliberate misspelling of `width`; validation must reject the
  # unknown name. The check is made through both rank plot functions, because
  # each one forwards `interval_args` to the shared validator on its own.
  bad_args <- list(with = 0.8, alpha = 0.1) # intended typo

  expect_error(
    mcmc_rank_overlay(
      arr,
      regex_pars = "beta",
      n_bins = 15,
      ref_line = TRUE,
      ref_interval = TRUE,
      interval_args = bad_args
    ),
    "is not TRUE"
  )
  expect_error(
    mcmc_rank_hist(
      arr,
      regex_pars = "beta",
      n_bins = 15,
      ref_line = TRUE,
      ref_interval = TRUE,
      interval_args = bad_args
    ),
    "is not TRUE"
  )
})



# displaying divergences in traceplot -------------------------------------
test_that("mcmc_trace 'np' argument works", {
Expand Down