Skip to content

Commit

Permalink
Merge pull request #1295 from paul-buerkner/logistic-normal
Browse files Browse the repository at this point in the history
Implement the logistic-normal family
  • Loading branch information
paul-buerkner authored Feb 7, 2022
2 parents a3e1561 + 7b115b7 commit 3164328
Show file tree
Hide file tree
Showing 31 changed files with 605 additions and 174 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.16.6
Date: 2022-02-01
Version: 2.16.7
Date: 2022-02-06
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
role = c("aut", "cre")),
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ export(dhurdle_negbinomial)
export(dhurdle_poisson)
export(dinv_gaussian)
export(dirichlet)
export(dlogistic_normal)
export(dmulti_normal)
export(dmulti_student_t)
export(do_call)
Expand Down Expand Up @@ -434,6 +435,7 @@ export(launch_shinystan)
export(lf)
export(log_lik)
export(log_posterior)
export(logistic_normal)
export(logit_scaled)
export(logm1)
export(lognormal)
Expand Down Expand Up @@ -548,6 +550,7 @@ export(rfrechet)
export(rgen_extreme_value)
export(rhat)
export(rinv_gaussian)
export(rlogistic_normal)
export(rmulti_normal)
export(rmulti_student_t)
export(rows2labels)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# brms 2.16.3++

### New Features

* Add family `logistic_normal` for simplex responses. (#1274)

### Other changes

* Argument `brms_seed` has been added to `get_refmodel.brmsfit()`. (#1287)
Expand Down
135 changes: 75 additions & 60 deletions R/brmsfit-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -451,63 +451,93 @@ get_theta <- function(prep, i = NULL) {

# get posterior draws of multivariate mean vectors
# only used in multivariate models with 'rescor'
# and in univariate models with multiple 'mu' pars such as logistic_normal
get_Mu <- function(prep, i = NULL) {
stopifnot(is.mvbrmsprep(prep))
Mu <- prep$mvpars$Mu
if (is.null(Mu)) {
Mu <- lapply(prep$resps, get_dpar, "mu", i = i)
if (length(i) == 1L) {
Mu <- do_call(cbind, Mu)
} else {
# keep correct dimension even if data has only 1 row
Mu <- lapply(Mu, as.matrix)
Mu <- abind::abind(Mu, along = 3)
}
is_mv <- is.mvbrmsprep(prep)
if (is_mv) {
Mu <- prep$mvpars$Mu
} else {
stopifnot(is.brmsprep(prep))
Mu <- prep$dpars$Mu
}
if (!is.null(Mu)) {
stopifnot(!is.null(i))
Mu <- slice_col(Mu, i)
return(Mu)
}
if (is_mv) {
Mu <- lapply(prep$resps, get_dpar, "mu", i = i)
} else {
mu_dpars <- str_subset(names(prep$dpars), "^mu")
Mu <- lapply(mu_dpars, get_dpar, prep = prep, i = i)
}
if (length(i) == 1L) {
Mu <- do_call(cbind, Mu)
} else {
# keep correct dimension even if data has only 1 row
Mu <- lapply(Mu, as.matrix)
Mu <- abind::abind(Mu, along = 3)
}
Mu
}

# get posterior draws of residual covariance matrices
# only used in multivariate models with 'rescor'
get_Sigma <- function(prep, i = NULL) {
stopifnot(is.mvbrmsprep(prep))
Sigma <- prep$mvpars$Sigma
if (is.null(Sigma)) {
stopifnot(!is.null(prep$mvpars$rescor))
sigma <- named_list(names(prep$resps))
for (j in seq_along(sigma)) {
sigma[[j]] <- get_dpar(prep$resps[[j]], "sigma", i = i)
sigma[[j]] <- add_sigma_se(sigma[[j]], prep$resps[[j]], i = i)
}
is_matrix <- ulapply(sigma, is.matrix)
if (!any(is_matrix)) {
# happens if length(i) == 1 or if no sigma was predicted
sigma <- do_call(cbind, sigma)
Sigma <- get_cov_matrix(sigma, prep$mvpars$rescor)
} else {
for (j in seq_along(sigma)) {
# bring all sigmas to the same dimension
if (!is_matrix[j]) {
sigma[[j]] <- array(sigma[[j]], dim = dim_mu(prep))
}
}
nsigma <- length(sigma)
sigma <- abind(sigma, along = 3)
Sigma <- array(dim = c(dim_mu(prep), nsigma, nsigma))
for (n in seq_len(ncol(Sigma))) {
Sigma[, n, , ] <- get_cov_matrix(sigma[, n, ], prep$mvpars$rescor)
}
}
# and in univariate models with multiple 'mu' pars such as logistic_normal
# Arguments:
#   prep: an object passing is.mvbrmsprep() or, failing that, is.brmsprep();
#     anything else trips the stopifnot() below
#   i: optional index forwarded to get_dpar(), add_sigma_se() and slice_col();
#     must not be NULL when a stored Sigma is reused
#   cor_name: name of the element of prep$dpars that holds the correlations;
#     overwritten by "rescor" for mvbrmsprep objects
# Returns: the stored Sigma (sliced by 'i' if 4-dimensional) or a new one
#   built via get_cov_matrix() from the 'sigma' parameters and correlations
get_Sigma <- function(prep, i = NULL, cor_name = NULL) {
is_mv <- is.mvbrmsprep(prep)
if (is_mv) {
# the multivariate case always reads its correlations from 'rescor'
cor_name <- "rescor"
Sigma <- prep$mvpars$Sigma
} else {
stopifnot(is.brmsprep(prep))
# NOTE(review): as_one_character() presumably rejects the NULL default,
# i.e. univariate callers have to supply 'cor_name' -- confirm
cor_name <- as_one_character(cor_name)
Sigma <- prep$dpars$Sigma
}
if (!is.null(Sigma)) {
# already computed before
stopifnot(!is.null(i))
ldim <- length(dim(Sigma))
stopifnot(ldim %in% 3:4)
# only a 4-dimensional Sigma is sliced by 'i';
# a 3-dimensional one is returned unchanged
if (ldim == 4L) {
Sigma <- slice_col(Sigma, i)
}
return(Sigma)
}
if (is_mv) {
# one 'sigma' per response, each taken from that response's own prep object
cors <- prep$mvpars[[cor_name]]
sigma <- named_list(names(prep$resps))
for (j in seq_along(sigma)) {
sigma[[j]] <- get_dpar(prep$resps[[j]], "sigma", i = i)
sigma[[j]] <- add_sigma_se(sigma[[j]], prep$resps[[j]], i = i)
}
} else {
# one 'sigma' per element of prep$dpars whose name starts with "sigma"
cors <- prep$dpars[[cor_name]]
sigma_names <- str_subset(names(prep$dpars), "^sigma")
sigma <- named_list(sigma_names)
for (j in seq_along(sigma)) {
sigma[[j]] <- get_dpar(prep, sigma_names[j], i = i)
sigma[[j]] <- add_sigma_se(sigma[[j]], prep, i = i)
}
}
is_matrix <- ulapply(sigma, is.matrix)
if (!any(is_matrix)) {
# happens if length(i) == 1 or if no sigma was predicted
sigma <- do_call(cbind, sigma)
Sigma <- get_cov_matrix(sigma, cors)
} else {
for (j in seq_along(sigma)) {
# bring all sigmas to the same dimension
if (!is_matrix[j]) {
sigma[[j]] <- array(sigma[[j]], dim = dim_mu(prep))
}
}
nsigma <- length(sigma)
# stack the sigmas along a third dimension and fill Sigma
# one slice of its second dimension at a time
sigma <- abind(sigma, along = 3)
Sigma <- array(dim = c(dim_mu(prep), nsigma, nsigma))
for (n in seq_len(ncol(Sigma))) {
Sigma[, n, , ] <- get_cov_matrix(sigma[, n, ], cors)
}
}
Sigma
}
Expand Down Expand Up @@ -581,36 +611,21 @@ subset_thres <- function(prep, i) {
}

# helper function of 'get_dpar' to decide if
# the link function should be applied direclty
# the link function should be applied directly
apply_dpar_inv_link <- function(dpar, family) {
  # TRUE unless the family has a joint link and 'dpar' is of class "mu"
  # (written in De Morgan form; evaluation order of the two checks is kept)
  (!has_joint_link(family)) || (dpar_class(dpar, family) != "mu")
}

# insert zeros for the predictor term of the reference category
# in categorical-like models using the softmax response function
insert_refcat <- function(eta, family) {
stopifnot(is.array(eta), is.brmsfamily(family))
if (!conv_cats_dpars(family) || isNA(family$refcat)) {
return(eta)
}
insert_refcat <- function(eta, refcat = 1) {
stopifnot(is.array(eta))
refcat <- as_one_integer(refcat)
# need to add zeros for the reference category
ndim <- length(dim(eta))
dim_noncat <- dim(eta)[-ndim]
zeros_arr <- array(0, dim = c(dim_noncat, 1))
if (is.null(family$refcat) || is.null(family$cats)) {
# no information on the categories provided:
# use the first category as the reference
return(abind::abind(zeros_arr, eta))
}
ncat <- length(family$cats)
stopifnot(identical(dim(eta)[ndim], ncat - 1L))
if (is.null(dimnames(eta)[[ndim]])) {
dimnames(eta)[ndim] <-
list(paste0("mu", setdiff(family$cats, family$refcat)))
}
dimnames(zeros_arr)[ndim] <- list(paste0("mu", family$refcat))
iref <- match(family$refcat, family$cats)
before <- seq_len(iref - 1)
before <- seq_len(refcat - 1)
after <- setdiff(seq_dim(eta, ndim), before)
abind::abind(
slice(eta, ndim, before, drop = FALSE),
Expand Down
18 changes: 11 additions & 7 deletions R/brmsformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -1320,14 +1320,18 @@ validate_formula.brmsformula <- function(
}
predcats <- setdiff(out$family$cats, out$family$refcat)
}
mu_dpars <- make_stan_names(paste0("mu", predcats))
if (any(duplicated(mu_dpars))) {
stop2("Invalid response category names. Please avoid ",
"using any special characters in the names.")
multi_dpars <- valid_dpars(out$family, multi = TRUE)
# 'rev' so that mu comes last but appears first in the end
for (dp in rev(multi_dpars)) {
dp_dpars <- make_stan_names(paste0(dp, predcats))
if (any(duplicated(dp_dpars))) {
stop2("Invalid response category names. Please avoid ",
"using any special characters in the names.")
}
old_dp_dpars <- str_subset(out$family$dpars, paste0("^", dp))
out$family$dpars <- setdiff(out$family$dpars, old_dp_dpars)
out$family$dpars <- union(dp_dpars, out$family$dpars)
}
old_mu_dpars <- str_subset(out$family$dpars, "^mu")
out$family$dpars <- setdiff(out$family$dpars, old_mu_dpars)
out$family$dpars <- union(mu_dpars, out$family$dpars)
}

# incorporate deprecated arguments
Expand Down
4 changes: 2 additions & 2 deletions R/data-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ data_response.brmsterms <- function(x, data, check_response = TRUE,
stop2("This model requires a response matrix.")
}
}
if (is_dirichlet(x$family)) {
if (is_simplex(x$family)) {
if (!is_equal(rowSums(out$Y), rep(1, nrow(out$Y)))) {
stop2("Response values in dirichlet models must sum to 1.")
stop2("Response values in simplex models must sum to 1.")
}
}
ybounds <- family_info(x$family, "ybounds")
Expand Down
56 changes: 45 additions & 11 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,45 @@ rmulti_student_t <- function(n, df, mu, Sigma, check = FALSE) {
sweep(draws, 2, mu, "+")
}

#' The (Multivariate) Logistic Normal Distribution
#'
#' Density function and random generation for the (multivariate) logistic normal
#' distribution with latent mean vector \code{mu} and covariance matrix \code{Sigma}.
#'
#' @name LogisticNormal
#'
#' @inheritParams StudentT
#' @param x Vector or matrix of quantiles. If \code{x} is a matrix,
#' each row is taken to be a quantile.
#' @param mu Mean vector with length equal to the number of dimensions.
#' @param Sigma Covariance matrix.
#' @param refcat A single integer indicating the reference category.
#' Defaults to \code{1}.
#' @param check Logical; Indicates whether several input checks
#' should be performed. Defaults to \code{FALSE} to improve
#' efficiency.
#'
#' @export
dlogistic_normal <- function(x, mu, Sigma, refcat = 1, log = FALSE,
                             check = FALSE) {
  # a plain vector (or 1-d array) is treated as a single quantile, i.e. one row
  if (is.vector(x) || length(dim(x)) == 1L) {
    x <- matrix(x, ncol = length(x))
  }
  # apply the categorical link with 'refcat' as the reference category
  lx <- link_categorical(x, refcat)
  # multivariate normal log-density of the linked values minus the row-wise
  # sum of log(x); 'check' is forwarded so that the documented input checks
  # are actually performed instead of being silently ignored
  out <- dmulti_normal(lx, mu, Sigma, log = TRUE, check = check) -
    rowSums(log(x))
  if (!log) {
    out <- exp(out)
  }
  out
}

#' @rdname LogisticNormal
#' @export
rlogistic_normal <- function(n, mu, Sigma, refcat = 1, check = FALSE) {
  # draw from the multivariate normal and push the draws through the
  # categorical inverse link, inserting 'refcat' as the reference category
  inv_link_categorical(
    rmulti_normal(n, mu, Sigma, check = check),
    refcat = refcat
  )
}

#' The Skew-Normal Distribution
#'
#' Density, distribution function, and random generation for the
Expand Down Expand Up @@ -1960,7 +1999,7 @@ dcategorical <- function(x, eta, log = FALSE) {
if (length(dim(eta)) != 2L) {
stop2("eta must be a numeric vector or matrix.")
}
out <- inv_link_categorical(eta, refcat_obj = NULL, log = log)
out <- inv_link_categorical(eta, log = log, refcat = NULL)
out[, x, drop = FALSE]
}

Expand All @@ -1972,23 +2011,18 @@ dcategorical <- function(x, eta, log = FALSE) {
# dcategorical()) or an array (S x N x `ncat` or S x N x `ncat - 1` (depending
# on `refcat`)) containing the same values as the matrix just described,
# but for N observations.
# @param refcat_obj Either the string "first", an object of class "brmsfamily",
# or NULL. If "first", then the first category is used as reference and
# corresponding values are inserted into `x`. If an object of class
# "brmsfamily", then passed to insert_refcat() which is used to insert values
# for the reference category into `x`. If NULL, `x` is not modified at all.
# @param refcat Integer indicating the reference category to be inserted in 'x'.
# If NULL, `x` is not modified at all.
# @param log Logical (length 1) indicating whether to log the return value.
#
# @return If `x` is a matrix, then a matrix (S x `ncat`, with S denoting the
# number of posterior draws and `ncat` denoting the number of response
# categories) containing the values of the inverse-link function applied to
# `x`. If `x` is an array, then an array (S x N x `ncat`) containing the same
# values as the matrix just described, but for N observations.
inv_link_categorical <- function(x, refcat_obj = "first", log = FALSE) {
if (is_equal(refcat_obj, "first")) {
x <- insert_refcat(x, family = categorical()) # The link does not matter.
} else if (!is.null(refcat_obj)) {
x <- insert_refcat(x, family = refcat_obj)
inv_link_categorical <- function(x, refcat = 1, log = FALSE) {
if (!is.null(refcat)) {
x <- insert_refcat(x, refcat = refcat)
}
if (log) {
out <- log_softmax(x)
Expand Down
4 changes: 2 additions & 2 deletions R/exclude_pars.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ exclude_pars.mvbrmsterms <- function(x, save_pars, ...) {

#' @export
exclude_pars.brmsterms <- function(x, save_pars, ...) {
out <- character(0)
out <- "Lncor"
resp <- usc(combine_prefix(x))
if (!save_pars$all) {
par_classes <- c("ordered_Intercept", "fixed_Intercept", "theta")
par_classes <- c("ordered_Intercept", "fixed_Intercept", "theta", "Llncor")
c(out) <- paste0(par_classes, resp)
}
for (dp in names(x$dpars)) {
Expand Down
Loading

0 comments on commit 3164328

Please sign in to comment.