Skip to content

Commit

Permalink
Add multinomial distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Mar 3, 2023
1 parent 189c366 commit e10ecd1
Show file tree
Hide file tree
Showing 19 changed files with 695 additions and 311 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# dynamite 1.3.0

* Added support for Student's t-distribution via `"student"` family in `obs`.
* Added support for the multinomial distribution via `"multinomial"` family
in `obs`.

# dynamite 1.2.1

Expand Down
30 changes: 21 additions & 9 deletions R/as_data_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ as.data.table.dynamitefit <- function(x, keep.rownames = FALSE,
attr(x$stan$responses, "resp_class")[[response]],
"levels"
)[-1L]
channel <- get_channel(x, response)
if (is_multinomial(channel$family)) {
category <- channel$y[-1L]
}
if (is.null(category)) {
category <- NA
}
Expand Down Expand Up @@ -280,6 +284,14 @@ as.data.table.dynamitefit <- function(x, keep.rownames = FALSE,
out
}

#' Retrieve the Channel Definition of a Response
#'
#' Looks up `response` in `x$stan$channel_vars` and, when no entry is found
#' there, falls back to the entry of `x$stan$channel_group_vars`.
#'
#' @param x An object with a `stan` element holding the `channel_vars` and
#'   `channel_group_vars` lists.
#' @param response \[`character(1)`]\cr Name of the response.
#' @return The element stored under `response`; `NULL` when neither list
#'   contains it.
#' @noRd
get_channel <- function(x, response) {
  channel <- x$stan$channel_vars[[response]]
  if (is.null(channel)) {
    # No entry in `channel_vars`: use the one from `channel_group_vars`.
    channel <- x$stan$channel_group_vars[[response]]
  }
  channel
}

#' Construct a Data Table for a Parameter Type from a `dynamitefit` Object
#'
#' Arguments for all as_data_frame_type functions are documented here.
Expand Down Expand Up @@ -329,13 +341,13 @@ as_data_table_corr_nu <- function(x, draws, n_draws, ...) {
#' @noRd
as_data_table_nu <- function(x, draws, n_draws, response, ...) {
icpt <- ifelse_(
x$stan$channel_vars[[response]]$has_random_intercept,
get_channel(x, response)$has_random_intercept,
"alpha",
NULL
)
var_names <- paste0(
"nu_", response, "_",
c(icpt, names(x$stan$channel_vars[[response]]$J_random))
c(icpt, names(get_channel(x, response)$J_random))
)
n_vars <- length(var_names)
groups <- sort(unique(x$data[[x$group_var]]))
Expand All @@ -354,7 +366,7 @@ as_data_table_alpha <- function(x, draws, n_draws,
n_cat <- length(category)
fixed <- x$stan$fixed
all_time_points <- sort(unique(x$data[[x$time_var]]))
if (x$stan$channel_vars[[response]]$has_varying_intercept) {
if (get_channel(x, response)$has_varying_intercept) {
time_points <- ifelse_(
include_fixed,
all_time_points,
Expand Down Expand Up @@ -390,7 +402,7 @@ as_data_table_alpha <- function(x, draws, n_draws,
as_data_table_beta <- function(x, draws, n_draws, response, category, ...) {
var_names <- paste0(
"beta_", response, "_",
names(x$stan$channel_vars[[response]]$J_fixed)
names(get_channel(x, response)$J_fixed)
)
n_vars <- length(var_names)
data.table::data.table(
Expand All @@ -409,7 +421,7 @@ as_data_table_delta <- function(x, draws, n_draws,
all_time_points <- sort(unique(x$data[[x$time_var]]))
var_names <- paste0(
"delta_", response, "_",
names(x$stan$channel_vars[[response]]$J_varying)
names(get_channel(x, response)$J_varying)
)
n_vars <- length(var_names)
time_points <- ifelse_(
Expand Down Expand Up @@ -441,7 +453,7 @@ as_data_table_delta <- function(x, draws, n_draws,
as_data_table_tau <- function(x, draws, n_draws, response, ...) {
var_names <- paste0(
"tau_", response, "_",
names(x$stan$channel_vars[[response]]$J_varying)
names(get_channel(x, response)$J_varying)
)
data.table::data.table(
parameter = rep(var_names, each = n_draws),
Expand All @@ -456,7 +468,7 @@ as_data_table_omega <- function(x, draws, n_draws, response, category, ...) {
D <- x$stan$sampling_vars$D
var_names <- paste0(
"omega_", response, "_",
names(x$stan$channel_vars[[response]]$J_varying)
names(get_channel(x, response)$J_varying)
)
k <- length(var_names)
data.table::data.table(
Expand Down Expand Up @@ -504,13 +516,13 @@ as_data_table_sigma <- function(draws, response, ...) {
#' @noRd
as_data_table_sigma_nu <- function(x, draws, n_draws, response, ...) {
icpt <- ifelse_(
x$stan$channel_vars[[response]]$has_random_intercept,
get_channel(x, response)$has_random_intercept,
"alpha",
NULL
)
var_names <- paste0(
"sigma_nu_", response, "_",
c(icpt, names(x$stan$channel_vars[[response]]$J_random))
c(icpt, names(get_channel(x, response)$J_random))
)
data.table::data.table(
parameter = rep(var_names, each = n_draws),
Expand Down
5 changes: 2 additions & 3 deletions R/default_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,9 @@ default_priors <- function(y, channel, mean_gamma, sd_gamma, mean_y, sd_y) {
#' Standard deviation of the explanatory variables at time `fixed + 1`.
#' @param resp_class \[`character(1)`]\cr Class of the response variable.
#' @noRd
default_priors_categorical <- function(y, channel, sd_x, resp_class) {
S_y <- length(attr(resp_class, "levels"))
default_priors_categorical <- function(y, channel, sd_x, S_y, resp_levels) {
# remove the first level which acts as reference
resp_levels <- attr(resp_class, "levels")[-1]
resp_levels <- resp_levels[-1L]
sd_gamma <- signif(pmax(2 / sd_x, 1), 2)
priors <- list()
if (channel$has_fixed_intercept || channel$has_varying_intercept) {
Expand Down
13 changes: 11 additions & 2 deletions R/dynamiteformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ parse_formula <- function(x, original, family) {
formula_parts <- strsplit(formula_str, "|", fixed = TRUE)[[1L]]
n_formulas <- length(formula_parts)
n_responses <- length(responses)
mn <- is_multinomial(family)
mvf <- is_multivariate(family)
mvc <- n_responses > 1L
stopifnot_(
Expand All @@ -294,9 +295,17 @@ parse_formula <- function(x, original, family) {
)
stopifnot_(
!mvc || n_formulas == n_responses || n_formulas == 1L,
"Number of component formulas ({n_formulas}) must be 1 or
the number of dimensions: {n_responses}."
c(
"Number of component formulas must be 1 or
the number of dimensions: {n_responses}",
`x` = "{n_formulas} formulas were provided."
)
)
#stopifnot_(
# (mvf && !mn) || (mn && n_formulas == 1L),
# "Only a single formula must be provided for
# univariate and multinomial channels."
#)
formula_parts <- ifelse_(
n_formulas == 1L,
rep(formula_parts, n_responses),
Expand Down
11 changes: 10 additions & 1 deletion R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ is_supported <- function(name) {
#' @param x \[`dynamitefamily`]\cr A family object.
#' @noRd
is_multivariate <- function(x) {
  # Family names treated as multivariate. `%in%` returns one logical per
  # element of `x$name` and never yields `NA`.
  x$name %in% c("mvgaussian", "multinomial")
}

supported_families <- c(
"binomial",
"bernoulli", # separate as Stan has more efficient pmf for it
"categorical",
"multinomial",
"negbin",
"gaussian",
"mvgaussian",
Expand All @@ -55,6 +56,14 @@ supported_families <- c(
"student"
)

#' Test If Multivariate Family Uses Univariate Components
#'
#' The check is `TRUE` for every name listed in `supported_families` except
#' `"multinomial"`.
#'
#' @param x \[`dynamitefamily`]\cr A family object.
#' @return A logical vector with one element per element of `x$name`.
#' @noRd
has_univariate <- function(x) {
  # All supported family names apart from "multinomial".
  eligible <- setdiff(supported_families, "multinomial")
  x$name %in% eligible
}

#' Get Univariate Version of a Multivariate Family
#'
#' @param x \[`dynamitefamily`]\cr A family object.
Expand Down
Loading

0 comments on commit e10ecd1

Please sign in to comment.