diff --git a/R/brm.R b/R/brm.R index 88af096a3..44c3cf5ba 100644 --- a/R/brm.R +++ b/R/brm.R @@ -34,6 +34,7 @@ #' If not specified, default links are used. #' For details of supported families see #' \code{\link[brms:brmsfamily]{brmsfamily}}. +#' By default, a linear \code{gaussian} model is applied. #' @param prior One or more \code{brmsprior} objects created by #' \code{\link[brms:set_prior]{set_prior}} or related functions #' and combined using the \code{c} method. A single \code{brmsprior} @@ -324,7 +325,7 @@ #' @import methods #' @import stats #' @export -brm <- function(formula, data, family = gaussian(), prior = NULL, +brm <- function(formula, data, family = NULL, prior = NULL, autocor = NULL, nonlinear = NULL, threshold = c("flexible", "equidistant"), cov_ranef = NULL, save_ranef = TRUE, save_mevars = FALSE, @@ -352,7 +353,6 @@ brm <- function(formula, data, family = gaussian(), prior = NULL, "Forking is now automatically applied when appropriate.") } dots[deprecated_brm_args()] <- NULL - check_brm_input(nlist(family, inits)) autocor <- check_autocor(autocor) threshold <- match.arg(threshold) algorithm <- match.arg(algorithm) @@ -369,10 +369,11 @@ brm <- function(formula, data, family = gaussian(), prior = NULL, x$fit <- rstan::get_stanmodel(x$fit) } else { # build new model # see validate.R and formula-helpers.R - family <- check_family(family) formula <- amend_formula(formula, data = data, family = family, nonlinear = nonlinear) - bterms <- parse_bf(formula, family = family, autocor = autocor) + family <- formula$family + check_brm_input(nlist(family)) + bterms <- parse_bf(formula, autocor = autocor) if (is.null(dots$data.name)) { data.name <- substr(Reduce(paste, deparse(substitute(data))), 1, 50) } else { diff --git a/R/brmsfit-helpers.R b/R/brmsfit-helpers.R index a3b9256c6..9711556d1 100644 --- a/R/brmsfit-helpers.R +++ b/R/brmsfit-helpers.R @@ -60,6 +60,13 @@ restructure <- function(x, rstr_summary = FALSE) { # deprecated as of brms 1.4.0 class(x$autocor) <- "cor_fixed" } + if (x$version <= "0.9.1") { + # update gaussian("log") to lognormal() family + nresp <- length(bterms$response) + if (is_old_lognormal(x$family, nresp = nresp, version = x$version)) { + object$family <- object$formula$family <- lognormal() + } + } if (x$version <= "0.10.0.9000") { if (length(bterms$nlpars)) { # nlpar and group have changed positions diff --git a/R/brmsfit-methods.R b/R/brmsfit-methods.R index a089a4b02..e2adb940b 100644 --- a/R/brmsfit-methods.R +++ b/R/brmsfit-methods.R @@ -818,11 +818,11 @@ formula.brmsfit <- function(x, ...) { #' @export family.brmsfit <- function(object, ...) { - if (is(object$family, "family")) { - # brms > 0.6.0 - family <- object$family - } else { + if (is.character(object$family)) { + # brms <= 0.6.0 family <- brmsfamily(object$family, link = object$link) + } else { + family <- object$family } family } @@ -1913,27 +1913,28 @@ update.brmsfit <- function(object, formula., newdata = NULL, ...) { if (missing(formula.)) { dots$formula <- object$formula } else { - recompile <- length(pforms(formula.)) > 0L + family <- get_arg("family", dots, formula., object) + nl <- get_arg("nl", formula., formula(object)) + dots$formula <- bf(formula., family = family, nl = nl) + recompile <- length(pforms(dots$formula)) > 0L if (is_nonlinear(object)) { - if (length(setdiff(all.vars(formula.), ".")) == 0L) { - dots$formula <- update(object$formula, formula., mode = "keep") + if (length(setdiff(all.vars(dots$formula$formula), ".")) == 0L) { + dots$formula <- update(object$formula, dots$formula, mode = "keep") } else { - dots$formula <- update(object$formula, formula., mode = "replace") + dots$formula <- update(object$formula, dots$formula, mode = "replace") message("Argument 'formula.' will completely replace the ", "original formula in non-linear models.") recompile <- TRUE } } else { - dots$formula <- as.formula(formula.) - mvars <- setdiff(all.vars(dots$formula), c(names(object$data), ".")) + mvars <- all.vars(dots$formula$formula) + mvars <- setdiff(mvars, c(names(object$data), ".")) if (length(mvars) && is.null(newdata)) { stop2("New variables found: ", paste(mvars, collapse = ", "), "\nPlease supply your data again via argument 'newdata'") } dots$formula <- update(formula(object), dots$formula) ee_old <- parse_bf(formula(object)) - family <- get_arg("family", dots, object) - dots$formula <- amend_formula(dots$formula, family = family) ee_new <- parse_bf(dots$formula) # no need to recompile the model when changing fixed effects only dont_change <- c("random", "gam", "cs", "mo", "me") @@ -1945,22 +1946,13 @@ update.brmsfit <- function(object, formula., newdata = NULL, ...) { ee_new[names(ee_new) %in% dont_change]) || is_equal(sort(c(n_old_fixef, n_new_fixef)), c(0L, 1L)) || length(ee_old$response) != length(ee_new$response) || - length(pforms(formula.)) > 0L + length(pforms(formula.)) > 0L || + !identical(dots$formula$family, family(object)) } if (recompile) { message("The desired formula changes require recompling the model") } } - # allow to change the non-linear part via argument 'nonlinear' - take_nl <- !is.null(dots$nonlinear) && - (missing(formula.) || is.null(attr(formula., "nonlinear"))) - if (take_nl) attr(dots$formula, "nonlinear") <- NULL - # update gaussian("log") to lognormal() family - resp <- parse_bf(object$formula, family = object$family)$response - if (is_old_lognormal(object$family, nresp = length(resp), - version = object$version)) { - object$family <- lognormal() - } dots$iter <- first_not_null(dots$iter, object$fit@sim$iter) # brm computes warmup automatically based on iter @@ -1972,11 +1964,10 @@ update.brmsfit <- function(object, formula., newdata = NULL, ...) { recompile <- recompile || length(new_args) if (recompile) { if (length(new_args)) { - message(paste("Changing argument(s)", - paste0("'", new_args, "'", collapse = ", "), - "requires recompiling the model")) + message("Changing argument(s) ", collapse_comma(new_args), + " requires recompiling the model") } - old_args <- setdiff(rc_args, new_args) + old_args <- setdiff(rc_args, c(new_args, "family")) dots[old_args] <- object[old_args] if (!is.null(newdata)) { dots$data <- newdata diff --git a/R/brmsformula.R b/R/brmsformula.R index a6c54eb27..43e161287 100644 --- a/R/brmsformula.R +++ b/R/brmsformula.R @@ -445,7 +445,7 @@ #' bf(rt | dec(decision) ~ x, bias = 0.5) #' #' @export -brmsformula <- function(formula, ..., flist = NULL, +brmsformula <- function(formula, ..., flist = NULL, family = NULL, nl = NULL, nonlinear = NULL) { # ensure backwards compatibility if (is.brmsformula(formula) && is.formula(formula)) { @@ -505,20 +505,23 @@ brmsformula <- function(formula, ..., flist = NULL, } out[["nl"]] <- nl } + if (!is.null(family)) { + out[["family"]] <- check_family(family) + } # add default values for unspecified elements - defs <- list(pforms = list(), nl = FALSE, family = NULL, - response = NULL, old_mv = FALSE) - defs <- defs[setdiff(names(defs), names(out))] + defs <- list(pforms = list(), pfix = list(), family = NULL, + nl = FALSE, response = NULL, old_mv = FALSE) + defs <- defs[setdiff(names(defs), names(rmNULL(out, FALSE)))] out[names(defs)] <- defs class(out) <- "brmsformula" out } #' @export -bf <- function(formula, ..., flist = NULL, +bf <- function(formula, ..., flist = NULL, family = NULL, nl = NULL, nonlinear = NULL) { # alias of brmsformula - brmsformula(formula, ..., flist = flist, + brmsformula(formula, ..., flist = flist, family = family, nl = nl, nonlinear = nonlinear) } @@ -656,10 +659,15 @@ update.brmsformula <- function(object, formula., # a brmsformula object mode <- match.arg(mode) object <- bf(object) + up_family <- formula.[["family"]] + if (is.null(up_family)) { + up_family <- object[["family"]] + } up_nl <- formula.[["nl"]] if (is.null(up_nl)) { up_nl <- object[["nl"]] } + # already use up_nl here to avoid ordinary parsing of NL formulas formula. <- bf(formula., nl = up_nl) old_form <- object$formula up_form <- formula.$formula @@ -672,12 +680,12 @@ update.brmsformula <- function(object, formula., } pforms <- pforms(object) up_pforms <- pforms(formula.) - pforms[names(up_pforms)] <- up_pforms - - nl <- get_arg("nl", formula., object) - out <- bf(new_form, flist = pforms, nl = nl) - out$family <- get_arg("family", formula., object) - out + pforms[names(up_pforms)] <- up_pforms + pfix <- pfix(object) + up_pfix <- pfix(formula.) + pfix[names(up_pfix)] <- up_pfix + bf(new_form, flist = c(pforms, pfix), + family = up_family, nl = up_nl) } #' @export diff --git a/R/extract_draws.R b/R/extract_draws.R index cd982ec70..689a13261 100644 --- a/R/extract_draws.R +++ b/R/extract_draws.R @@ -129,7 +129,7 @@ extract_draws <- function(x, newdata = NULL, re_formula = NULL, if (nzchar(nlpar)) { # make sure not to evaluate family specific stuff # when extracting draws of nlpars - x$formula[["response"]] <- nlpar + x$formula[["response"]] <- nlpar na_family <- list(family = NA, link = "identity") class(na_family) <- c("brmsfamily", "family") x$family <- x$formula$family <- dots$f <- na_family diff --git a/R/families.R b/R/families.R index feb35fada..5d4d2ace8 100644 --- a/R/families.R +++ b/R/families.R @@ -460,7 +460,7 @@ check_family <- function(family, link = NULL) { family <- family() } if (!is(family, "brmsfamily")) { - if (is(family, "family")) { + if (is.family(family)) { link <- family$link family <- family$family } @@ -503,101 +503,101 @@ is.family <- function(x) { is_linear <- function(family) { # indicate if family is for a linear model - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("gaussian", "student", "cauchy") + isTRUE(family %in% c("gaussian", "student", "cauchy")) } is_binary <- function(family) { # indicate if family is bernoulli or binomial - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("binomial", "bernoulli") + isTRUE(family %in% c("binomial", "bernoulli")) } is_ordinal <- function(family) { # indicate if family is for an ordinal model - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("cumulative", "cratio", "sratio", "acat") + isTRUE(family %in% c("cumulative", "cratio", "sratio", "acat")) } is_categorical <- function(family) { - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% "categorical" + isTRUE(family %in% "categorical") } is_skewed <- function(family) { # indicate if family is for model with postive skewed response - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("gamma", "weibull", "exponential", "frechet") + isTRUE(family %in% c("gamma", "weibull", "exponential", "frechet")) } is_lognormal <- function(family) { # indicate if family is lognormal - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("lognormal") + isTRUE(family %in% c("lognormal")) } is_exgaussian <- function(family) { # indicate if family is exgaussian - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("exgaussian") + isTRUE(family %in% c("exgaussian")) } is_wiener <- function(family) { # indicate if family is the wiener diffusion model - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("wiener") + isTRUE(family %in% c("wiener")) } is_asym_laplace <- function(family) { # indicates if family is asymmetric laplace - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("asym_laplace") + isTRUE(family %in% c("asym_laplace")) } is_count <- function(family) { # indicate if family is for a count model - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("poisson", "negbinomial", "geometric") + isTRUE(family %in% c("poisson", "negbinomial", "geometric")) } is_hurdle <- function(family, zi_beta = TRUE) { # indicate if family is for a hurdle model - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } # zi_beta is technically a hurdle model - family %in% c("hurdle_poisson", "hurdle_negbinomial", "hurdle_gamma", - "hurdle_lognormal", if (zi_beta) "zero_inflated_beta") + isTRUE(family %in% c("hurdle_poisson", "hurdle_negbinomial", "hurdle_gamma", + "hurdle_lognormal", if (zi_beta) "zero_inflated_beta")) } is_zero_inflated <- function(family, zi_beta = FALSE) { # indicate if family is for a zero inflated model - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } # zi_beta is technically a hurdle model - family %in% c("zero_inflated_poisson", "zero_inflated_negbinomial", - "zero_inflated_binomial", if (zi_beta) "zero_inflated_beta") + isTRUE(family %in% c("zero_inflated_poisson", "zero_inflated_negbinomial", + "zero_inflated_binomial", if (zi_beta) "zero_inflated_beta")) } is_2PL <- function(family) { @@ -609,9 +609,8 @@ is_2PL <- function(family) { out <- family$family %in% "bernoulli" && identical(family$type, "2PL") } if (out) { - stop("The special implementation of 2PL models has been removed.\n", - "You can now use argument 'nonlinear' to fit such models.", - call. = FALSE) + stop2("The special implementation of 2PL models has been removed.\n", + "You can now use argument 'nonlinear' to fit such models.") } out } @@ -634,36 +633,36 @@ is_mv <- function(family, response = NULL) { use_real <- function(family) { # indicate if family uses real responses - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } is_linear(family) || is_skewed(family) || - family %in% c("lognormal", "exgaussian", "inverse.gaussian", "beta", - "von_mises", "zero_inflated_beta", "hurdle_gamma", - "hurdle_lognormal", "wiener", "asym_laplace") + isTRUE(family %in% c("lognormal", "exgaussian", "inverse.gaussian", "beta", + "von_mises", "zero_inflated_beta", "hurdle_gamma", + "hurdle_lognormal", "wiener", "asym_laplace")) } use_int <- function(family) { # indicate if family uses integer responses - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } is_binary(family) || has_cat(family) || is_count(family) || is_zero_inflated(family) || - family %in% c("hurdle_poisson", "hurdle_negbinomial") + isTRUE(family %in% c("hurdle_poisson", "hurdle_negbinomial")) } has_trials <- function(family) { # indicate if family makes use of argument trials - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("binomial", "zero_inflated_binomial") + isTRUE(family %in% c("binomial", "zero_inflated_binomial")) } has_cat <- function(family) { # indicate if family makes use of argument cat - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } is_categorical(family) || is_ordinal(family) @@ -671,44 +670,44 @@ has_cat <- function(family) { has_shape <- function(family) { # indicate if family needs a shape parameter - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("gamma", "weibull", "inverse.gaussian", - "negbinomial", "hurdle_negbinomial", - "hurdle_gamma", "zero_inflated_negbinomial") + isTRUE(family %in% c("gamma", "weibull", "inverse.gaussian", + "negbinomial", "hurdle_negbinomial", + "hurdle_gamma", "zero_inflated_negbinomial")) } has_nu <- function(family) { # indicate if family needs a nu parameter - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("student", "frechet") + isTRUE(family %in% c("student", "frechet")) } has_phi <- function(family) { # indicate if family needs a phi parameter - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("beta", "zero_inflated_beta") + isTRUE(family %in% c("beta", "zero_inflated_beta")) } has_kappa <- function(family) { # indicate if family needs a kappa parameter - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("von_mises") + isTRUE(family %in% c("von_mises")) } has_beta <- function(family) { # indicate if family needs a kappa parameter - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("exgaussian") + isTRUE(family %in% c("exgaussian")) } has_sigma <- function(family, bterms = NULL, @@ -719,11 +718,11 @@ has_sigma <- function(family, bterms = NULL, # bterms: object of class brmsterms # autocor: object of class cor_arma # incmv: should MV (linear) models be treated as having sigma? - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - is_ln_eg <- family %in% c("lognormal", "hurdle_lognormal", - "exgaussian", "asym_laplace") + is_ln_eg <- isTRUE(family %in% c("lognormal", "hurdle_lognormal", + "exgaussian", "asym_laplace")) if (is.formula(bterms$se)) { # call .se without evaluating the x argument cl <- rhs(bterms$se)[[2]] @@ -737,7 +736,7 @@ has_sigma <- function(family, bterms = NULL, se_only <- FALSE } out <- (is_linear(family) || is_ln_eg) && - !se_only && !is(autocor, "cov_fixed") + !se_only && !is(autocor, "cov_fixed") if (!incmv) { is_multi <- is_linear(family) && length(bterms$response) > 1L out <- out && !is_multi @@ -747,30 +746,30 @@ has_sigma <- function(family, bterms = NULL, allows_cs <- function(family) { # checks if category specific effects are allowed - if (is(family, "family")) { + if (is.family(family)) { family <- family$family } - family %in% c("sratio", "cratio", "acat") + isTRUE(family %in% c("sratio", "cratio", "acat")) } -is_old_lognormal <- function(family, link = "identity", nresp = 1, +is_old_lognormal <- function(family, link = "identity", nresp = 1L, version = utils::packageVersion("brms")) { # indicate transformation to lognormal models # Args: # link: A character string; ignored if family is of class family # nresp: number of response variables # version: brms version with which the model was fitted - if (is(family, "family")) { + if (is.family(family)) { link <- family$link family <- family$family } - family %in% "gaussian" && link == "log" && nresp == 1 && + isTRUE(family %in% "gaussian") && link == "log" && nresp == 1L && (is.null(version) || version <= "0.9.1") } is_old_categorical <- function(x) { # indicate if the model is and old categorical model - stopifnot(is(x, "brmsfit")) + stopifnot(is.brmsfit(x)) if (is(x$fit, "stanfit") && is_categorical(x$family)) { if ("bp" %in% x$fit@model_pars) { # fitted with brms <= 0.8.0 diff --git a/R/formula-helpers.R b/R/formula-helpers.R index 28c41f82e..b7b9b68df 100644 --- a/R/formula-helpers.R +++ b/R/formula-helpers.R @@ -444,7 +444,7 @@ eval_rhs <- function(formula, data = NULL) { eval(rhs(formula)[[2]], data, environment(formula)) } -amend_formula <- function(formula, data = NULL, family = gaussian(), +amend_formula <- function(formula, data = NULL, family = NULL, nonlinear = NULL, partial = NULL) { # incorporate additional arguments into formula # Args: @@ -454,8 +454,7 @@ amend_formula <- function(formula, data = NULL, family = gaussian(), # nonlinear, partial: deprecated arguments of brm # Returns: # a brmsformula object compatible with the current version of brms - out <- bf(formula, nonlinear = nonlinear) - out[["family"]] <- family + out <- bf(formula, family = family, nonlinear = nonlinear) fnew <- ". ~ ." if (!is.null(partial)) { warning2("Argument 'partial' is deprecated. Please use the 'cs' ", @@ -471,13 +470,16 @@ amend_formula <- function(formula, data = NULL, family = gaussian(), if (fnew != ". ~ .") { out$formula <- update.formula(out$formula, formula(fnew)) } - if (is_ordinal(family)) { + if (is.null(out$family)) { + out$family <- check_family(gaussian()) + } + if (is_ordinal(out$family)) { # fix discrimination to 1 by default if (!"disc" %in% c(names(pforms(out)), names(pfix(out)))) { out <- bf(out, disc = 1) } } - if (is_categorical(family) && is.null(attr(formula, "response"))) { + if (is_categorical(out$family) && is.null(out[["response"]])) { respform <- parse_bf(out)$respform model_response <- model.response(model.frame(respform, data = data)) response <- levels(factor(model_response)) diff --git a/R/make_stancode.R b/R/make_stancode.R index a81a832ab..fd372add1 100644 --- a/R/make_stancode.R +++ b/R/make_stancode.R @@ -17,7 +17,7 @@ #' data = epilepsy, family = "poisson") #' #' @export -make_stancode <- function(formula, data, family = gaussian(), +make_stancode <- function(formula, data, family = NULL, prior = NULL, autocor = NULL, nonlinear = NULL, threshold = c("flexible", "equidistant"), sparse = FALSE, cov_ranef = NULL, @@ -30,9 +30,9 @@ make_stancode <- function(formula, data, family = gaussian(), save_model <- use_alias(save_model, dots$save.model) dots[c("cov.ranef", "sample.prior", "save.model")] <- NULL # some input checks - family <- check_family(family) formula <- amend_formula(formula, data = data, family = family, nonlinear = nonlinear) + family <- formula$family autocor <- check_autocor(autocor) threshold <- match.arg(threshold) bterms <- parse_bf(formula, family = family, autocor = autocor) diff --git a/R/make_standata.R b/R/make_standata.R index 39f980096..561c19494 100644 --- a/R/make_standata.R +++ b/R/make_standata.R @@ -24,7 +24,7 @@ #' names(data2) #' #' @export -make_standata <- function(formula, data, family = "gaussian", +make_standata <- function(formula, data, family = NULL, prior = NULL, autocor = NULL, nonlinear = NULL, cov_ranef = NULL, sample_prior = FALSE, knots = NULL, control = list(), ...) { @@ -40,9 +40,9 @@ make_standata <- function(formula, data, family = "gaussian", # use deprecated arguments if specified cov_ranef <- use_alias(cov_ranef, dots$cov.ranef, warn = FALSE) # some input checks - family <- check_family(family) formula <- amend_formula(formula, data = data, family = family, - nonlinear = nonlinear) + nonlinear = nonlinear) + family <- formula$family old_mv <- isTRUE(formula[["old_mv"]]) autocor <- check_autocor(autocor) is_linear <- is_linear(family) diff --git a/R/misc.R b/R/misc.R index c214ef840..f2aa66fc1 100644 --- a/R/misc.R +++ b/R/misc.R @@ -50,6 +50,10 @@ isFALSE <- function(x) { identical(FALSE, x) } +isNA <- function(x) { + identical(NA, x) +} + is_equal <- function(x, y, ...) { isTRUE(all.equal(x, y, ...)) } diff --git a/R/priors.R b/R/priors.R index b55fc25a9..0e4caa5d4 100644 --- a/R/priors.R +++ b/R/priors.R @@ -502,15 +502,15 @@ prior_string <- function(prior, ...) { #' prior = prior) #' #' @export -get_prior <- function(formula, data, family = gaussian(), +get_prior <- function(formula, data, family = NULL, autocor = NULL, nonlinear = NULL, threshold = c("flexible", "equidistant"), internal = FALSE) { # note that default priors are stored in this function - family <- check_family(family) - link <- family$link formula <- amend_formula(formula, data = data, family = family, nonlinear = nonlinear) + family <- formula$family + link <- family$link threshold <- match.arg(threshold) autocor <- check_autocor(autocor) bterms <- parse_bf(formula, family = family) @@ -937,7 +937,7 @@ check_prior_content <- function(prior, family = gaussian(), warn = TRUE) { if (!is(prior, "brmsprior")) { return(invisible(NULL)) } - stopifnot(is(family, "family")) + stopifnot(is.family(family)) family <- family$family if (nrow(prior)) { lb_priors <- c("lognormal", "chi_square", "inv_chi_square", diff --git a/R/validate.R b/R/validate.R index 196eb0d30..1f6e18f3a 100644 --- a/R/validate.R +++ b/R/validate.R @@ -28,21 +28,16 @@ #' \code{\link[brms:brmsformula]{brmsformula}} #' #' @export -parse_bf <- function(formula, family = NA, autocor = NULL, +parse_bf <- function(formula, family = NULL, autocor = NULL, check_response = TRUE, resp_rhs_all = TRUE) { - x <- bf(formula) + x <- bf(formula, family = family) old_mv <- isTRUE(x[["old_mv"]]) - if (!is.null(x[["family"]])) { - family <- x[["family"]] - } - if (!is.na(family[[1]])) { - family <- check_family(family) - } if (!(is.null(autocor) || is.cor_brms(autocor))) { stop2("Argument 'autocor' has to be of class 'cor_brms'") } formula <- x$formula + family <- x$family y <- nlist(formula) add_forms <- parse_add(formula, family, check_response) add_vars <- str2formula(ulapply(add_forms, all.vars)) @@ -187,7 +182,7 @@ parse_bf <- function(formula, family = NA, autocor = NULL, y } -parse_add <- function(formula, family = NA, check_response = TRUE) { +parse_add <- function(formula, family = NULL, check_response = TRUE) { # extract addition arguments out formula # Args: # see parse_bf @@ -196,7 +191,7 @@ parse_add <- function(formula, family = NA, check_response = TRUE) { x <- list() add_funs <- lsp("brms", what = "exports", pattern = "^resp_") add_funs <- sub("^resp_", "", add_funs) - if (!is.na(family[[1]])) { + if (!is.null(family) && !isNA(family$family)) { add <- get_matches("\\|[^~]*~", formula2str(formula)) if (length(add)) { # replace deprecated '|' by '+' @@ -259,13 +254,13 @@ parse_mo <- function(formula) { structure(mo_terms, pos = pos_mo_terms) } -parse_cs <- function(formula, family = NA) { +parse_cs <- function(formula, family = NULL) { # category specific terms for ordinal models all_terms <- all_terms(formula) pos_cs_terms <- grepl("^cse?\\([^\\|]+$", all_terms) cs_terms <- all_terms[pos_cs_terms] if (length(cs_terms)) { - if (!is.na(family[[1]]) && !allows_cs(family)) { + if (!is.null(family) && !allows_cs(family)) { stop2("Category specific effects are only meaningful for ", "families 'sratio', 'cratio', and 'acat'.") } diff --git a/man/brm.Rd b/man/brm.Rd index ba249a35e..242c1ff16 100644 --- a/man/brm.Rd +++ b/man/brm.Rd @@ -4,7 +4,7 @@ \alias{brm} \title{Fit Bayesian Generalized (Non-)Linear Multilevel Models} \usage{ -brm(formula, data, family = gaussian(), prior = NULL, autocor = NULL, +brm(formula, data, family = NULL, prior = NULL, autocor = NULL, nonlinear = NULL, threshold = c("flexible", "equidistant"), cov_ranef = NULL, save_ranef = TRUE, save_mevars = FALSE, sparse = FALSE, sample_prior = FALSE, knots = NULL, stan_funs = NULL, @@ -33,7 +33,8 @@ Every family function has a \code{link} argument allowing to specify the link function to be applied on the response variable. If not specified, default links are used. For details of supported families see -\code{\link[brms:brmsfamily]{brmsfamily}}.} +\code{\link[brms:brmsfamily]{brmsfamily}}. +By default, a linear \code{gaussian} model is applied.} \item{prior}{One or more \code{brmsprior} objects created by \code{\link[brms:set_prior]{set_prior}} or related functions diff --git a/man/brmsformula.Rd b/man/brmsformula.Rd index 38f0e59c2..611fefd5a 100644 --- a/man/brmsformula.Rd +++ b/man/brmsformula.Rd @@ -5,7 +5,8 @@ \alias{brmsformula} \title{Set up a model formula for use in \pkg{brms}} \usage{ -brmsformula(formula, ..., flist = NULL, nl = NULL, nonlinear = NULL) +brmsformula(formula, ..., flist = NULL, family = NULL, nl = NULL, + nonlinear = NULL) } \arguments{ \item{formula}{An object of class \code{formula} @@ -46,6 +47,16 @@ See 'Details' for more explanation.} \item{flist}{Optional list of formulas, which are treated in the same way as formulas passed via the \code{...} argument.} +\item{family}{A description of the response distribution and link function +to be used in the model. This can be a family function, +a call to a family function or a character string naming the family. +Every family function has a \code{link} argument allowing to specify +the link function to be applied on the response variable. +If not specified, default links are used. +For details of supported families see +\code{\link[brms:brmsfamily]{brmsfamily}}. +By default, a linear \code{gaussian} model is applied.} + \item{nl}{Logical; Indicates whether \code{formula} should be treated as specifying a non-linear model. By default, \code{formula} is treated as an ordinary linear model formula.} diff --git a/man/get_prior.Rd b/man/get_prior.Rd index b64422bc3..a28482968 100644 --- a/man/get_prior.Rd +++ b/man/get_prior.Rd @@ -4,9 +4,8 @@ \alias{get_prior} \title{Overview on Priors for \pkg{brms} Models} \usage{ -get_prior(formula, data, family = gaussian(), autocor = NULL, - nonlinear = NULL, threshold = c("flexible", "equidistant"), - internal = FALSE) +get_prior(formula, data, family = NULL, autocor = NULL, nonlinear = NULL, + threshold = c("flexible", "equidistant"), internal = FALSE) } \arguments{ \item{formula}{An object of class @@ -28,7 +27,8 @@ Every family function has a \code{link} argument allowing to specify the link function to be applied on the response variable. If not specified, default links are used. For details of supported families see -\code{\link[brms:brmsfamily]{brmsfamily}}.} +\code{\link[brms:brmsfamily]{brmsfamily}}. +By default, a linear \code{gaussian} model is applied.} \item{autocor}{An optional \code{\link{cor_brms}} object describing the correlation structure within the response variable diff --git a/man/make_stancode.Rd b/man/make_stancode.Rd index 87e85c014..8e3fc9266 100644 --- a/man/make_stancode.Rd +++ b/man/make_stancode.Rd @@ -4,9 +4,9 @@ \alias{make_stancode} \title{Stan Code for \pkg{brms} Models} \usage{ -make_stancode(formula, data, family = gaussian(), prior = NULL, - autocor = NULL, nonlinear = NULL, threshold = c("flexible", - "equidistant"), sparse = FALSE, cov_ranef = NULL, sample_prior = FALSE, +make_stancode(formula, data, family = NULL, prior = NULL, autocor = NULL, + nonlinear = NULL, threshold = c("flexible", "equidistant"), + sparse = FALSE, cov_ranef = NULL, sample_prior = FALSE, stan_funs = NULL, save_model = NULL, ...) } \arguments{ @@ -29,7 +29,8 @@ Every family function has a \code{link} argument allowing to specify the link function to be applied on the response variable. If not specified, default links are used. For details of supported families see -\code{\link[brms:brmsfamily]{brmsfamily}}.} +\code{\link[brms:brmsfamily]{brmsfamily}}. +By default, a linear \code{gaussian} model is applied.} \item{prior}{One or more \code{brmsprior} objects created by \code{\link[brms:set_prior]{set_prior}} or related functions diff --git a/man/make_standata.Rd b/man/make_standata.Rd index e6d1536c4..bfa23708a 100644 --- a/man/make_standata.Rd +++ b/man/make_standata.Rd @@ -5,9 +5,9 @@ \alias{make_standata} \title{Data for \pkg{brms} Models} \usage{ -make_standata(formula, data, family = "gaussian", prior = NULL, - autocor = NULL, nonlinear = NULL, cov_ranef = NULL, - sample_prior = FALSE, knots = NULL, control = list(), ...) +make_standata(formula, data, family = NULL, prior = NULL, autocor = NULL, + nonlinear = NULL, cov_ranef = NULL, sample_prior = FALSE, + knots = NULL, control = list(), ...) } \arguments{ \item{formula}{An object of class @@ -29,7 +29,8 @@ Every family function has a \code{link} argument allowing to specify the link function to be applied on the response variable. If not specified, default links are used. For details of supported families see -\code{\link[brms:brmsfamily]{brmsfamily}}.} +\code{\link[brms:brmsfamily]{brmsfamily}}. +By default, a linear \code{gaussian} model is applied.} \item{prior}{One or more \code{brmsprior} objects created by \code{\link[brms:set_prior]{set_prior}} or related functions diff --git a/man/parse_bf.Rd b/man/parse_bf.Rd index bc7e7e846..4a69f50f7 100644 --- a/man/parse_bf.Rd +++ b/man/parse_bf.Rd @@ -4,7 +4,7 @@ \alias{parse_bf} \title{Parse Formulas of \pkg{brms} Models} \usage{ -parse_bf(formula, family = NA, autocor = NULL, check_response = TRUE, +parse_bf(formula, family = NULL, autocor = NULL, check_response = TRUE, resp_rhs_all = TRUE) } \arguments{ @@ -23,7 +23,8 @@ Every family function has a \code{link} argument allowing to specify the link function to be applied on the response variable. If not specified, default links are used. For details of supported families see -\code{\link[brms:brmsfamily]{brmsfamily}}.} +\code{\link[brms:brmsfamily]{brmsfamily}}. +By default, a linear \code{gaussian} model is applied.} \item{autocor}{An optional \code{\link{cor_brms}} object describing the correlation structure within the response variable diff --git a/tests/testthat/tests.brmsfit-methods.R b/tests/testthat/tests.brmsfit-methods.R index cecc13cb0..a99a6b7b8 100644 --- a/tests/testthat/tests.brmsfit-methods.R +++ b/tests/testthat/tests.brmsfit-methods.R @@ -360,6 +360,10 @@ test_that("all S3 methods have reasonable ouputs", { expect_true(is(up, "brmsfit")) up <- update(fit2, formula. = count ~ a + b, testmode = TRUE) expect_true(is(up, "brmsfit")) + up <- update(fit3, family = acat(), testmode = TRUE) + expect_true(is(up, "brmsfit")) + up <- update(fit3, bf(~., family = acat()), testmode = TRUE) + expect_true(is(up, "brmsfit")) # VarCorr vc <- VarCorr(fit1) diff --git a/tests/testthat/tests.data-helpers.R b/tests/testthat/tests.data-helpers.R index 547631ca4..cde539e2e 100644 --- a/tests/testthat/tests.data-helpers.R +++ b/tests/testthat/tests.data-helpers.R @@ -38,9 +38,9 @@ test_that("(deprecated) melt_data keeps factor contrasts", { test_that("(deprecated) melt_data returns expected errors", { data <- data.frame(y1 = rnorm(10), y2 = rnorm(10), x = 1:10) - formula <- bf(y1 ~ x:main) + formula <- bf(y1 ~ x:main, family = hurdle_poisson()) formula$old_mv <- TRUE - bterms <- brms:::parse_bf(formula, family = hurdle_poisson()) + bterms <- parse_bf(formula) expect_error(melt_data(data = NULL, family = hurdle_poisson(), bterms = bterms), "'data' must be a data.frame", fixed = TRUE) @@ -49,16 +49,16 @@ test_that("(deprecated) melt_data returns expected errors", { "'main' is a reserved variable name", fixed = TRUE) data$response <- 1:10 - formula <- bf(response ~ x:main) + formula <- bf(response ~ x:main, family = hurdle_poisson()) formula$old_mv <- TRUE - bterms <- parse_bf(formula, family = hurdle_poisson()) + bterms <- parse_bf(formula) expect_error(melt_data(data = data, family = hurdle_poisson(), bterms = bterms), "'response' is a reserved variable name", fixed = TRUE) data$trait <- 1:10 - formula <- bf(y ~ 0 + x*trait) + formula <- bf(y ~ 0 + x*trait, family = hurdle_poisson()) formula$old_mv <- TRUE - bterms <- parse_bf(formula, family = hurdle_poisson()) + bterms <- parse_bf(formula) expect_error(melt_data(data = data, family = hurdle_poisson(), bterms = bterms), "'trait', 'response' is a reserved variable name", fixed = TRUE) @@ -114,9 +114,9 @@ test_that("(deprecated) update_data handles NAs correctly in old MV models", { expect_equivalent(mf, data.frame(response = c(1, 3, 4, 6), y1 = c(1, 3, 1, 3), y2 = c(4, 6, 4, 6), x = c(10, 12, 10, 12))) - formula <- bf(y1 ~ x) + formula <- bf(y1 ~ x, family = "hurdle_gamma") formula$old_mv <- TRUE - bterms <- parse_bf(formula, family = "hurdle_gamma") + bterms <- parse_bf(formula) expect_warning(mf <- update_data(data, family = "hurdle_gamma", bterms = bterms), "NAs were excluded") @@ -124,7 +124,8 @@ test_that("(deprecated) update_data handles NAs correctly in old MV models", { y1 = c(1, 3, 1, 3), x = c(10, 12, 10, 12))) - bterms <- parse_bf(formula, family = "zero_inflated_poisson") + formula$family <- zero_inflated_poisson() + bterms <- parse_bf(formula) expect_warning(mf <- update_data(data, family = "zero_inflated_poisson", bterms = bterms), "NAs were excluded") expect_equivalent(mf, data.frame(response = c(1, 3, 1, 3), y1 = c(1, 3, 1, 3),