Skip to content

Commit

Permalink
logistic_normal: enable to predict 'sigma'
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Feb 6, 2022
1 parent cddbf13 commit 08beee6
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 38 deletions.
10 changes: 2 additions & 8 deletions R/brmsformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -1320,14 +1320,8 @@ validate_formula.brmsformula <- function(
}
predcats <- setdiff(out$family$cats, out$family$refcat)
}
dpars_data <- "mu"
if (is_logistic_normal(out$family)) {
# "sigma" parameters are also determined from the data
# evaluate "mu" last so that it is listed first in the end
# TODO: generalize interface to store dpars_data in the family?
dpars_data <- c("sigma", dpars_data)
}
for (dp in dpars_data) {
multi_dpars <- valid_dpars(out$family, multi = TRUE)
for (dp in multi_dpars) {
dp_dpars <- make_stan_names(paste0(dp, predcats))
if (any(duplicated(dp_dpars))) {
stop2("Invalid response category names. Please avoid ",
Expand Down
32 changes: 21 additions & 11 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
)
out[names(family_info)] <- family_info
class(out) <- c("brmsfamily", "family")
for (dp in valid_dpars(out)) {
all_valid_dpars <- c(valid_dpars(out), valid_dpars(out, multi = TRUE))
for (dp in all_valid_dpars) {
alink <- as.character(aux_links[[paste0("link_", dp)]])
if (length(alink)) {
alink <- as_one_character(alink)
Expand Down Expand Up @@ -639,10 +640,11 @@ dirichlet2 <- function(link = "log") {

#' @rdname brmsfamily
#' @export
logistic_normal <- function(link = "identity", refcat = NULL) {
# TODO: allow to pass and store 'link_sigma' as well
logistic_normal <- function(link = "identity", link_sigma = "log",
refcat = NULL) {
slink <- substitute(link)
.brmsfamily("logistic_normal", link = link, slink = slink, refcat = refcat)
.brmsfamily("logistic_normal", link = link, slink = slink,
link_sigma = link_sigma, refcat = refcat)
}

#' @rdname brmsfamily
Expand Down Expand Up @@ -1211,12 +1213,17 @@ valid_dpars <- function(family, ...) {
}

#' @export
valid_dpars.default <- function(family, ...) {
valid_dpars.default <- function(family, multi = FALSE, ...) {
if (!length(family)) {
return("mu")
}
family <- validate_family(family)
family_info(family, "dpars", ...)
if (multi) {
out <- family_info(family, "multi_dpars", ...)
} else {
out <- family_info(family, "dpars", ...)
}
out
}

#' @export
Expand Down Expand Up @@ -1258,13 +1265,16 @@ dpar_class <- function(dpar, family = NULL) {
out <- sub("[[:digit:]]*$", "", dpar)
if (!is.null(family)) {
# TODO: avoid these special cases by changing naming conventions
if (conv_cats_dpars(family) && grepl("^mu", out)) {
# perhaps add a protected "C" before category names
# and a protected "M" for mixture components
if (conv_cats_dpars(family)) {
# categorical-like models have non-integer suffixes
# that will not be caught by the standard procedure
out <- "mu"
}
if (is_logistic_normal(family) && grepl("^sigma", out)) {
out <- "sigma"
multi_dpars <- valid_dpars(family, multi = TRUE)
for (dp in multi_dpars) {
sel <- grepl(paste0("^", dp), out)
out[sel] <- dp
}
}
}
out
Expand Down
15 changes: 10 additions & 5 deletions R/family-lists.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
.family_categorical <- function() {
list(
links = "logit",
dpars = NULL, # is determined based on the data
dpars = NULL,
multi_dpars = "mu", # size determined by the data
type = "int", ybounds = c(-Inf, Inf),
closed = c(NA, NA),
ad = c("weights", "subset", "index"),
Expand All @@ -84,7 +85,8 @@
.family_multinomial <- function() {
list(
links = "logit",
dpars = NULL, # is determined based on the data
dpars = NULL,
multi_dpars = "mu", # size determined by the data
type = "int", ybounds = c(-Inf, Inf),
closed = c(NA, NA),
ad = c("weights", "subset", "trials", "index"),
Expand All @@ -109,7 +111,8 @@
.family_dirichlet <- function() {
list(
links = "logit",
dpars = "phi", # more dpars are determined based on the data
dpars = "phi",
multi_dpars = "mu", # size determined by the data
type = "real", ybounds = c(0, 1),
closed = c(FALSE, FALSE),
ad = c("weights", "subset", "index"),
Expand All @@ -122,7 +125,8 @@
.family_dirichlet2 <- function() {
list(
links = c("log", "softplus", "squareplus", "identity", "logm1"),
dpars = NULL, # is fully determined based on the data
dpars = NULL,
multi_dpars = "mu", # size determined by the data
type = "real", ybounds = c(0, 1),
closed = c(FALSE, FALSE),
ad = c("weights", "subset", "index"),
Expand All @@ -135,7 +139,8 @@
.family_logistic_normal <- function() {
list(
links = "identity",
dpars = NULL, # is fully determined based on the data
dpars = NULL,
multi_dpars = c("mu", "sigma"), # size determined by the data
type = "real", ybounds = c(0, 1),
closed = c(FALSE, FALSE),
ad = c("weights", "subset", "index"),
Expand Down
28 changes: 16 additions & 12 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,16 @@ stan_log_lik_Y_name <- function(bterms) {
# @param reqn will the likelihood be wrapped in a loop over n?
# @param dpars optional names of distributional parameters to be prepared
# if not specified will prepare all distributional parameters
stan_log_lik_dpars <- function(bterms, reqn, resp = "", mix = "", dpars = NULL) {
stan_log_lik_dpars <- function(bterms, reqn, resp = "", mix = "", dpars = NULL,
multi = FALSE) {
if (is.null(dpars)) {
dpars <- paste0(valid_dpars(bterms), mix)
dpars <- paste0(valid_dpars(bterms, multi = multi), mix)
}
is_pred <- dpars %in% c("mu", names(bterms$dpars))
pred_dpars <- names(bterms$dpars)
if (multi) {
pred_dpars <- unique(dpar_class(pred_dpars, bterms))
}
is_pred <- dpars %in% pred_dpars
out <- paste0(dpars, resp, ifelse(reqn & is_pred, "[n]", ""))
named_list(dpars, out)
}
Expand Down Expand Up @@ -710,40 +715,39 @@ stan_log_lik_acat <- function(bterms, resp = "", mix = "",
stan_log_lik_categorical <- function(bterms, resp = "", mix = "", ...) {
stopifnot(bterms$family$link == "logit")
stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu")
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", multi = TRUE)
sdist("categorical_logit", p$mu)
}

stan_log_lik_multinomial <- function(bterms, resp = "", mix = "", ...) {
stopifnot(bterms$family$link == "logit")
stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu")
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", multi = TRUE)
sdist("multinomial_logit2", p$mu)
}

stan_log_lik_dirichlet <- function(bterms, resp = "", mix = "", ...) {
stopifnot(bterms$family$link == "logit")
stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed
mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu")$mu
mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", multi = TRUE)$mu
reqn <- glue("phi{mix}") %in% names(bterms$dpars)
phi <- stan_log_lik_dpars(bterms, reqn, resp, mix, dpars = "phi")$phi
sdist("dirichlet_logit", mu, phi)
}

stan_log_lik_dirichlet2 <- function(bterms, resp = "", mix = "", ...) {
stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed
mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu")$mu
mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", multi = TRUE)$mu
sdist("dirichlet", mu)
}

stan_log_lik_logistic_normal <- function(bterms, resp = "", mix = "", ...) {
stopifnot(bterms$family$link == "identity")
stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed
mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu")$mu
sigma <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "sigma")$sigma
Llncor <- glue("Llncor{mix}{resp}")
refcat <- get_refcat(bterms$family, int = TRUE)
sdist("logistic_normal_cholesky_cor", mu, sigma, Llncor, refcat)
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, multi = TRUE)
p$Llncor <- glue("Llncor{mix}{resp}")
p$refcat <- get_refcat(bterms$family, int = TRUE)
sdist("logistic_normal_cholesky_cor", p$mu, p$sigma, p$Llncor, p$refcat)
}

stan_log_lik_ordinal <- function(bterms, resp = "", mix = "",
Expand Down
2 changes: 1 addition & 1 deletion R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -2268,7 +2268,7 @@ stan_dpar_transform <- function(bterms, threads, ...) {
stopifnot(length(families) == 1L)
predcats <- get_predcats(bterms$family)
sigma_dpars <- glue("sigma{predcats}")
reqn <- sigma_dpars %in% bterms$dpars
reqn <- sigma_dpars %in% names(bterms$dpars)
n <- ifelse(reqn, "[n]", "")
sigma_dpars <- glue("{sigma_dpars}{p}{n}")
ncatm1 <- glue("ncat{p}-1")
Expand Down
2 changes: 1 addition & 1 deletion man/brmsfamily.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 08beee6

Please sign in to comment.