Merge pull request #169 from stan-dev/issue-168
mcmc_areas() fixes
jgabry committed Oct 23, 2018
2 parents 645c34a + 3689fc7 commit 7bb24d3
Showing 73 changed files with 6,502 additions and 5,193 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -42,6 +42,6 @@ Suggests:
shinystan (>= 2.3.0),
testthat (>= 2.0.0),
vdiffr
RoxygenNote: 6.0.1
RoxygenNote: 6.1.0
VignetteBuilder: knitr
Encoding: UTF-8
22 changes: 22 additions & 0 deletions NEWS.md
@@ -10,6 +10,28 @@
gains an argument `iter1` which can be used to label the traceplot starting
from the first iteration after warmup. (#14, #155, @mcol)

* [`mcmc_areas()`](http://mc-stan.org/bayesplot/reference/MCMC-intervals.html)
gains an argument `area_method` which controls how to draw the density
curves. The default `"equal area"` constrains the heights so that the curves
have the same area. As a result, a narrow interval will appear as a spike
of density, while a wide, uncertain interval is spread thin over the _x_ axis.
Alternatively, `"equal height"` will set the maximum height on each curve to
the same value. This works well when the intervals are about the same width.
Otherwise, that wide, uncertain interval will dominate the visual space
compared to a narrow, less uncertain interval. A compromise between the two is
`"scaled height"` which scales the curves from `"equal height"` using
`height * sqrt(height)`. (#163, #169)

* `mcmc_areas()` correctly plots density curves where the point estimate
does not include the highest point of the density curve.
(#168, #169, @jtimonen)

* `mcmc_areas_ridges()` draws the vertical line at *x* = 0 over the curves so
that it is always visible.

* `mcmc_intervals()` and `mcmc_areas()` raise a warning if `prob_outer` is
less than `prob`, then sort the two values into the correct order. (#138)

* MCMC parameter names are now *always* converted to factors prior to
plotting. We use factors so that the order of parameters in a plot matches
the order of the parameters in the original MCMC data. This change fixes a
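The three `area_method` options described in the NEWS entry above can be compared numerically. What follows is a minimal base-R sketch, not bayesplot's implementation: `area_heights()` and its column names are hypothetical. It assumes, following the `plotting_density` and `scaled_density` columns added in this commit, that `"equal area"` divides every density by the single largest density value across all curves, while `"equal height"` divides each curve by its own maximum.

```r
# Hypothetical helper (not part of bayesplot): compute the three height
# variants for a named list of numeric draws, using base R only.
area_heights <- function(draws, n = 512) {
  dens <- lapply(draws, function(v) stats::density(v, n = n))
  # Largest density value over all curves (used for "equal area")
  global_max <- max(vapply(dens, function(d) max(d$y), numeric(1)))
  lapply(dens, function(d) {
    scaled <- d$y / max(d$y)  # each curve peaks at 1 ("equal height")
    data.frame(
      x = d$x,
      equal_area = d$y / global_max,
      equal_height = scaled,
      scaled_height = scaled * sqrt(scaled)  # "scaled height"
    )
  })
}

set.seed(1)
draws <- list(
  narrow = rnorm(1000, sd = 0.2),  # sharp, confident posterior
  wide   = rnorm(1000, sd = 2)     # diffuse, uncertain posterior
)
heights <- area_heights(draws)

# Peak height of each curve under each method
sapply(heights, function(h) max(h$equal_area))
sapply(heights, function(h) max(h$equal_height))
sapply(heights, function(h) max(h$scaled_height))
```

Under these assumptions the wide curve's peak is much lower than the narrow curve's with `equal_area`, while both peak at the same height with the other two methods. Because the scaled values lie between 0 and 1, `scaled * sqrt(scaled)` never exceeds `scaled`, so the "scaled height" curve sits at or below the "equal height" curve everywhere except at the peak.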
120 changes: 95 additions & 25 deletions R/mcmc-intervals.R
@@ -18,6 +18,12 @@
#' @param prob_outer The probability mass to include in the outer interval. The
#' default is \code{0.9} for \code{mcmc_intervals} (90\% interval) and
#' \code{1} for \code{mcmc_areas} and for \code{mcmc_areas_ridges}.
#' @param area_method How to constrain the areas in \code{mcmc_areas}. The
#' default is \code{"equal area"}, setting the density curves to have the same
#' area. With \code{"equal height"}, the curves are scaled so that the highest
#' points across the curves are the same height. The method \code{"scaled
#' height"} tries a compromise between the two: the heights from
#' \code{"equal height"} are scaled using \code{height*sqrt(height)}.
#' @param point_est The point estimate to show. Either \code{"median"} (the
#' default), \code{"mean"}, or \code{"none"}.
#' @param rhat An optional numeric vector of \eqn{\hat{R}}{Rhat} estimates, with
@@ -76,6 +82,7 @@
#' fake_rhat_values <- c(1, 1.07, 1.3, 1.01, 1.15, 1.005)
#' mcmc_intervals(x, rhat = fake_rhat_values)
#'
#' # get the dataframe that is used in the plotting functions
#' mcmc_intervals_data(x)
#' mcmc_intervals_data(x, rhat = fake_rhat_values)
#' mcmc_areas_data(x, pars = "alpha")
@@ -85,6 +92,27 @@
#' p + legend_move("bottom")
#' p + legend_move("none") # or p + legend_none()
#'
#' # Different area calculations
#' b3 <- c("beta[1]", "beta[2]", "beta[3]")
#'
#' mcmc_areas(x, pars = b3, area_method = "equal area") +
#' ggplot2::labs(
#' title = "Curves have same area",
#' subtitle =
#' "A wide, uncertain interval is spread thin when areas are equal")
#'
#' mcmc_areas(x, pars = b3, area_method = "equal height") +
#' ggplot2::labs(
#' title = "Curves have same maximum height",
#' subtitle =
#' "Local curvature is clearer but more uncertain curves use more area")
#'
#' mcmc_areas(x, pars = b3, area_method = "scaled height") +
#' ggplot2::labs(
#' title = "Same maximum heights but heights scaled by square-root",
#' subtitle =
#' "Compromise: Local curvature is accentuated and less area is used")
#'
#' \donttest{
#' # apply transformations
#' mcmc_intervals(
@@ -223,6 +251,7 @@ mcmc_areas <- function(x,
regex_pars = character(),
transformations = list(),
...,
area_method = c("equal area", "equal height", "scaled height"),
prob = 0.5,
prob_outer = 1,
point_est = c("median", "mean", "none"),
@@ -232,6 +261,8 @@ mcmc_areas <- function(x,
kernel = NULL,
n_dens = NULL) {
check_ignored_arguments(...)
area_method <- match.arg(area_method)

data <- mcmc_areas_data(
x, pars, regex_pars, transformations,
prob = prob, prob_outer = prob_outer,
@@ -269,6 +300,14 @@ mcmc_areas <- function(x,
rlang::syms(c("parameter"))
}

if (area_method == "equal height") {
dens_col = ~ scaled_density
} else if (area_method == "scaled height") {
dens_col = ~ scaled_density * sqrt(scaled_density)
} else {
dens_col = ~ plotting_density
}

datas$bottom <- datas$outer %>%
group_by(!!! groups) %>%
summarise(ll = min(.data$x), hh = max(.data$x)) %>%
@@ -279,16 +318,16 @@ mcmc_areas <- function(x,
data = datas$bottom
)
args_inner <- list(
mapping = aes_(height = ~ density),
mapping = aes_(height = dens_col, scale = ~ .9),
data = datas$inner
)
args_point <- list(
mapping = aes_(height = ~ density),
mapping = aes_(height = dens_col, scale = ~ .9),
data = datas$point,
color = NA
)
args_outer <- list(
mapping = aes_(height = ~ density),
mapping = aes_(height = dens_col, scale = ~ .9),
fill = NA
)

@@ -314,9 +353,15 @@ mcmc_areas <- function(x,
}

layer_bottom <- do.call(geom_segment, args_bottom)
layer_inner <- do.call(geom_area_ridges, args_inner)
layer_point <- do.call(geom_area_ridges, args_point)
layer_outer <- do.call(geom_area_ridges, args_outer)
layer_inner <- do.call(ggridges::geom_ridgeline, args_inner)
layer_outer <- do.call(ggridges::geom_ridgeline, args_outer)

point_geom <- if (no_point_est) {
geom_ignore
} else {
ggridges::geom_ridgeline
}
layer_point <- do.call(point_geom, args_point)

# Do something or add an invisible layer
if (color_by_rhat) {
@@ -336,8 +381,9 @@ mcmc_areas <- function(x,
layer_bottom +
scale_color +
scale_fill +
scale_y_discrete(limits = unique(rev(data$parameter)),
expand = c(0.05, .6)) +
scale_y_discrete(
limits = unique(rev(data$parameter)),
expand = expand_scale(add = c(0, .1), mult = c(.1, .3))) +
xlim(x_lim) +
bayesplot_theme_get() +
legend_move(ifelse(color_by_rhat, "top", "none")) +
@@ -402,7 +448,7 @@ mcmc_areas_ridges <- function(x,
# Draw each ridgeline from top to bottom
layer_list_inner <- list()
par_draw_order <- levels(unique(data$parameter))
bg <- theme_get()[["panel.background"]][["fill"]] %||% "white"
bg <- bayesplot_theme_get()[["panel.background"]][["fill"]] %||% "white"

for (par_num in seq_along(unique(data$parameter))) {
# Basically, draw the current ridgeline normally, but draw all the ones
@@ -429,11 +475,11 @@ mcmc_areas_ridges <- function(x,

ggplot(datas$outer) +
aes_(x = ~ x, y = ~ parameter) +
layer_vertical_line +
layer_outer +
scale_y_discrete(limits = unique(rev(data$parameter)),
expand = c(0.05, .6)) +
layer_list_inner +
layer_vertical_line +
scale_fill_identity() +
scale_color_identity() +
xlim(x_lim) +
@@ -458,9 +504,9 @@ mcmc_intervals_data <- function(x,
point_est = c("median", "mean", "none"),
rhat = numeric()) {
check_ignored_arguments(...)
both_probs <- sort(c(prob, prob_outer))
prob <- both_probs[1]
prob_outer <- both_probs[2]
probs <- check_interval_widths(prob, prob_outer)
prob <- probs[1]
prob_outer <- probs[2]

x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
x <- merge_chains(x)
@@ -540,7 +586,7 @@ mcmc_areas_data <- function(x,
adjust = NULL,
kernel = NULL,
n_dens = NULL) {
probs <- sort(c(prob, prob_outer))
probs <- check_interval_widths(prob, prob_outer)

# First compute normal intervals so we know the width of the data, point
# estimates, and have prepared rhat values.
@@ -563,17 +609,25 @@ mcmc_areas_data <- function(x,

# Compute the density intervals
data_inner <- data_long %>%
compute_column_density(.data$parameter, .data$value,
interval_width = probs[1],
bw = bw, adjust = adjust, kernel = kernel,
n_dens = n_dens) %>%
compute_column_density(
group_vars = .data$parameter,
value_var = .data$value,
interval_width = probs[1],
bw = bw,
adjust = adjust,
kernel = kernel,
n_dens = n_dens) %>%
mutate(interval = "inner")

data_outer <- data_long %>%
compute_column_density(.data$parameter, .data$value,
interval_width = probs[2],
bw = bw, adjust = adjust, kernel = kernel,
n_dens = n_dens) %>%
compute_column_density(
group_vars = .data$parameter,
value_var = .data$value,
interval_width = probs[2],
bw = bw,
adjust = adjust,
kernel = kernel,
n_dens = n_dens) %>%
mutate(interval = "outer")

# Point estimates will be intervals that take up .8% of the x-axis
@@ -601,8 +655,9 @@ mcmc_areas_data <- function(x,
left_join(data_inner, by = "parameter") %>%
group_by(.data$parameter) %>%
dplyr::filter(abs(.data$center - .data$x) <= half_point_width) %>%
mutate(interval_width = 0,
interval = "point") %>%
mutate(
interval_width = 0,
interval = "point") %>%
select(-.data$center, .data$m) %>%
ungroup()

@@ -612,7 +667,10 @@ mcmc_areas_data <- function(x,
}

data <- dplyr::bind_rows(data_inner, data_outer, points) %>%
select(one_of("parameter", "interval", "interval_width", "x", "density"))
select(one_of("parameter", "interval", "interval_width",
"x", "density", "scaled_density")) %>%
# Density scaled so the highest in entire dataframe has height 1
mutate(plotting_density = .data$density / max(.data$density))

if (rlang::has_name(intervals, "rhat_value")) {
rhat_info <- intervals %>%
@@ -719,3 +777,15 @@ compute_interval_density <- function(x, interval_width = 1, n_dens = 1024,
scaled_density = dens$y / max(dens$y, na.rm = TRUE)
)
}

check_interval_widths <- function(prob, prob_outer) {
if (prob_outer < prob) {
x <- sprintf(
"`prob_outer` (%s) is less than `prob` (%s)\n... %s",
prob_outer,
prob,
"Swapping the values of `prob_outer` and `prob`")
warning(x, call. = FALSE)
}
sort(c(prob, prob_outer))
}
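
The behaviour of `check_interval_widths()` can be tried outside the package. The sketch below copies the helper from the diff above so that it runs on its own. The `withCallingHandlers()` wrapper and the `caught` variable are illustrative additions, not part of this commit.

```r
# Copy of the helper added in this commit, repeated so the example is
# self-contained.
check_interval_widths <- function(prob, prob_outer) {
  if (prob_outer < prob) {
    x <- sprintf(
      "`prob_outer` (%s) is less than `prob` (%s)\n... %s",
      prob_outer,
      prob,
      "Swapping the values of `prob_outer` and `prob`")
    warning(x, call. = FALSE)
  }
  sort(c(prob, prob_outer))
}

# Illustrative usage: record the warning text, then muffle the warning so
# the sorted values are still returned.
caught <- character()
probs <- withCallingHandlers(
  check_interval_widths(prob = 0.9, prob_outer = 0.5),
  warning = function(w) {
    caught <<- c(caught, conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)

probs   # inner width first, outer width second
caught  # the warning message raised by the helper
```

When the arguments are already in order, no warning is raised and the values are returned unchanged, so callers can always read the inner width from the first element and the outer width from the second.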
5 changes: 3 additions & 2 deletions man/MCMC-combos.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/MCMC-diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 9 additions & 8 deletions man/MCMC-distributions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
