Skip to content

Commit

Permalink
allow specifying 'family' in function 'brmsformula'
Browse files Browse the repository at this point in the history
Also clean up some related internal structures that are now behaving more consistently.
  • Loading branch information
paul-buerkner committed Jan 23, 2017
1 parent c08a05c commit ae29077
Show file tree
Hide file tree
Showing 20 changed files with 184 additions and 157 deletions.
9 changes: 5 additions & 4 deletions R/brm.R
Expand Up @@ -34,6 +34,7 @@
#' If not specified, default links are used.
#' For details of supported families see
#' \code{\link[brms:brmsfamily]{brmsfamily}}.
#' By default, a linear \code{gaussian} model is applied.
#' @param prior One or more \code{brmsprior} objects created by
#' \code{\link[brms:set_prior]{set_prior}} or related functions
#' and combined using the \code{c} method. A single \code{brmsprior}
Expand Down Expand Up @@ -324,7 +325,7 @@
#' @import methods
#' @import stats
#' @export
brm <- function(formula, data, family = gaussian(), prior = NULL,
brm <- function(formula, data, family = NULL, prior = NULL,
autocor = NULL, nonlinear = NULL,
threshold = c("flexible", "equidistant"),
cov_ranef = NULL, save_ranef = TRUE, save_mevars = FALSE,
Expand Down Expand Up @@ -352,7 +353,6 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
"Forking is now automatically applied when appropriate.")
}
dots[deprecated_brm_args()] <- NULL
check_brm_input(nlist(family, inits))
autocor <- check_autocor(autocor)
threshold <- match.arg(threshold)
algorithm <- match.arg(algorithm)
Expand All @@ -369,10 +369,11 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
x$fit <- rstan::get_stanmodel(x$fit)
} else { # build new model
# see validate.R and formula-helpers.R
family <- check_family(family)
formula <- amend_formula(formula, data = data, family = family,
nonlinear = nonlinear)
bterms <- parse_bf(formula, family = family, autocor = autocor)
family <- formula$family
check_brm_input(nlist(family))
bterms <- parse_bf(formula, autocor = autocor)
if (is.null(dots$data.name)) {
data.name <- substr(Reduce(paste, deparse(substitute(data))), 1, 50)
} else {
Expand Down
7 changes: 7 additions & 0 deletions R/brmsfit-helpers.R
Expand Up @@ -60,6 +60,13 @@ restructure <- function(x, rstr_summary = FALSE) {
# deprecated as of brms 1.4.0
class(x$autocor) <- "cor_fixed"
}
if (x$version <= "0.9.1") {
# update gaussian("log") to lognormal() family
nresp <- length(bterms$response)
if (is_old_lognormal(x$family, nresp = nresp, version = x$version)) {
object$family <- object$formula$family <- lognormal()
}
}
if (x$version <= "0.10.0.9000") {
if (length(bterms$nlpars)) {
# nlpar and group have changed positions
Expand Down
45 changes: 18 additions & 27 deletions R/brmsfit-methods.R
Expand Up @@ -818,11 +818,11 @@ formula.brmsfit <- function(x, ...) {

#' @export
family.brmsfit <- function(object, ...) {
if (is(object$family, "family")) {
# brms > 0.6.0
family <- object$family
} else {
if (is.character(object$family)) {
# brms <= 0.6.0
family <- brmsfamily(object$family, link = object$link)
} else {
family <- object$family
}
family
}
Expand Down Expand Up @@ -1913,27 +1913,28 @@ update.brmsfit <- function(object, formula., newdata = NULL, ...) {
if (missing(formula.)) {
dots$formula <- object$formula
} else {
recompile <- length(pforms(formula.)) > 0L
family <- get_arg("family", dots, formula., object)
nl <- get_arg("nl", formula., formula(object))
dots$formula <- bf(formula., family = family, nl = nl)
recompile <- length(pforms(dots$formula)) > 0L
if (is_nonlinear(object)) {
if (length(setdiff(all.vars(formula.), ".")) == 0L) {
dots$formula <- update(object$formula, formula., mode = "keep")
if (length(setdiff(all.vars(dots$formula$formula), ".")) == 0L) {
dots$formula <- update(object$formula, dots$formula, mode = "keep")
} else {
dots$formula <- update(object$formula, formula., mode = "replace")
dots$formula <- update(object$formula, dots$formula, mode = "replace")
message("Argument 'formula.' will completely replace the ",
"original formula in non-linear models.")
recompile <- TRUE
}
} else {
dots$formula <- as.formula(formula.)
mvars <- setdiff(all.vars(dots$formula), c(names(object$data), "."))
mvars <- all.vars(dots$formula$formula)
mvars <- setdiff(mvars, c(names(object$data), "."))
if (length(mvars) && is.null(newdata)) {
stop2("New variables found: ", paste(mvars, collapse = ", "),
"\nPlease supply your data again via argument 'newdata'")
}
dots$formula <- update(formula(object), dots$formula)
ee_old <- parse_bf(formula(object))
family <- get_arg("family", dots, object)
dots$formula <- amend_formula(dots$formula, family = family)
ee_new <- parse_bf(dots$formula)
# no need to recompile the model when changing fixed effects only
dont_change <- c("random", "gam", "cs", "mo", "me")
Expand All @@ -1945,22 +1946,13 @@ update.brmsfit <- function(object, formula., newdata = NULL, ...) {
ee_new[names(ee_new) %in% dont_change]) ||
is_equal(sort(c(n_old_fixef, n_new_fixef)), c(0L, 1L)) ||
length(ee_old$response) != length(ee_new$response) ||
length(pforms(formula.)) > 0L
length(pforms(formula.)) > 0L ||
!identical(dots$formula$family, family(object))
}
if (recompile) {
message("The desired formula changes require recompling the model")
}
}
# allow to change the non-linear part via argument 'nonlinear'
take_nl <- !is.null(dots$nonlinear) &&
(missing(formula.) || is.null(attr(formula., "nonlinear")))
if (take_nl) attr(dots$formula, "nonlinear") <- NULL
# update gaussian("log") to lognormal() family
resp <- parse_bf(object$formula, family = object$family)$response
if (is_old_lognormal(object$family, nresp = length(resp),
version = object$version)) {
object$family <- lognormal()
}

dots$iter <- first_not_null(dots$iter, object$fit@sim$iter)
# brm computes warmup automatically based on iter
Expand All @@ -1972,11 +1964,10 @@ update.brmsfit <- function(object, formula., newdata = NULL, ...) {
recompile <- recompile || length(new_args)
if (recompile) {
if (length(new_args)) {
message(paste("Changing argument(s)",
paste0("'", new_args, "'", collapse = ", "),
"requires recompiling the model"))
message("Changing argument(s) ", collapse_comma(new_args),
" requires recompiling the model")
}
old_args <- setdiff(rc_args, new_args)
old_args <- setdiff(rc_args, c(new_args, "family"))
dots[old_args] <- object[old_args]
if (!is.null(newdata)) {
dots$data <- newdata
Expand Down
32 changes: 20 additions & 12 deletions R/brmsformula.R
Expand Up @@ -445,7 +445,7 @@
#' bf(rt | dec(decision) ~ x, bias = 0.5)
#'
#' @export
brmsformula <- function(formula, ..., flist = NULL,
brmsformula <- function(formula, ..., flist = NULL, family = NULL,
nl = NULL, nonlinear = NULL) {
# ensure backwards compatibility
if (is.brmsformula(formula) && is.formula(formula)) {
Expand Down Expand Up @@ -505,20 +505,23 @@ brmsformula <- function(formula, ..., flist = NULL,
}
out[["nl"]] <- nl
}
if (!is.null(family)) {
out[["family"]] <- check_family(family)
}
# add default values for unspecified elements
defs <- list(pforms = list(), nl = FALSE, family = NULL,
response = NULL, old_mv = FALSE)
defs <- defs[setdiff(names(defs), names(out))]
defs <- list(pforms = list(), pfix = list(), family = NULL,
nl = FALSE, response = NULL, old_mv = FALSE)
defs <- defs[setdiff(names(defs), names(rmNULL(out, FALSE)))]
out[names(defs)] <- defs
class(out) <- "brmsformula"
out
}

#' @export
bf <- function(formula, ..., flist = NULL,
bf <- function(formula, ..., flist = NULL, family = NULL,
nl = NULL, nonlinear = NULL) {
# alias of brmsformula
brmsformula(formula, ..., flist = flist,
brmsformula(formula, ..., flist = flist, family = family,
nl = nl, nonlinear = nonlinear)
}

Expand Down Expand Up @@ -656,10 +659,15 @@ update.brmsformula <- function(object, formula.,
# a brmsformula object
mode <- match.arg(mode)
object <- bf(object)
up_family <- formula.[["family"]]
if (is.null(up_family)) {
up_family <- object[["family"]]
}
up_nl <- formula.[["nl"]]
if (is.null(up_nl)) {
up_nl <- object[["nl"]]
}
# already use up_nl here to avoid ordinary parsing of NL formulas
formula. <- bf(formula., nl = up_nl)
old_form <- object$formula
up_form <- formula.$formula
Expand All @@ -672,12 +680,12 @@ update.brmsformula <- function(object, formula.,
}
pforms <- pforms(object)
up_pforms <- pforms(formula.)
pforms[names(up_pforms)] <- up_pforms

nl <- get_arg("nl", formula., object)
out <- bf(new_form, flist = pforms, nl = nl)
out$family <- get_arg("family", formula., object)
out
pforms[names(up_pforms)] <- up_pforms
pfix <- pfix(object)
up_pfix <- pfix(formula.)
pfix[names(up_pfix)] <- up_pfix
bf(new_form, flist = c(pforms, pfix),
family = up_family, nl = up_nl)
}

#' @export
Expand Down
2 changes: 1 addition & 1 deletion R/extract_draws.R
Expand Up @@ -129,7 +129,7 @@ extract_draws <- function(x, newdata = NULL, re_formula = NULL,
if (nzchar(nlpar)) {
# make sure not to evaluate family specific stuff
# when extracting draws of nlpars
x$formula[["response"]] <- nlpar
x$formula[["response"]] <- nlpar
na_family <- list(family = NA, link = "identity")
class(na_family) <- c("brmsfamily", "family")
x$family <- x$formula$family <- dots$f <- na_family
Expand Down

0 comments on commit ae29077

Please sign in to comment.