Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify extra time implementation #336

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: sdmTMB
Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB'
Version: 0.5.0.9000
Version: 0.5.0.9001
Authors@R: c(
person(c("Sean", "C."), "Anderson", , "sean@seananderson.ca",
role = c("aut", "cre"),
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# sdmTMB (development version)

* Simplify the internal treatment of extra time slices (`extra_time`). #329
This is less bug-prone and also fixes a recently introduced bug. #335

# sdmTMB 0.5.0

* Overhaul residuals vignette ('article')
Expand Down
42 changes: 13 additions & 29 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ NULL
#' `extra_time` can also be used to fill in missing time steps for the purposes
#' of a random walk or AR(1) process if the gaps between time steps are uneven.
#'
#' `extra_time` can include only extra time steps or all time steps including
#' those found in the fitted data. This latter option may be simpler.
#'
#' **Regularization and priors**
#'
#' You can achieve regularization via penalties (priors) on the fixed effect
Expand Down Expand Up @@ -775,19 +778,7 @@ sdmTMB <- function(
offset <- data[[offset]]
}

if (!is.null(extra_time)) { # for forecasting or interpolating
data <- expand_time(df = data, time_slices = extra_time, time_column = time,
weights = weights, offset = offset, upr = upr)
offset <- data[["__sdmTMB_offset__"]] # expanded
weights <- data[["__weight_sdmTMB__"]] # expanded
upr <- data[["__dcens_upr__"]] # expanded
spde$loc_xy <- as.matrix(data[,spde$xy_cols,drop=FALSE])
spde$A_st <- fmesher::fm_basis(spde$mesh, loc = spde$loc_xy)
spde$sdm_spatial_id <- seq(1, nrow(data)) # FIXME?
} else {
data[["__fake_data__"]] <- FALSE
}
check_irregalar_time(data, time, spatiotemporal, time_varying)
check_irregalar_time(data, time, spatiotemporal, time_varying, extra_time = extra_time)

spatial_varying_formula <- spatial_varying # save it
if (!is.null(spatial_varying)) {
Expand All @@ -804,7 +795,6 @@ sdmTMB <- function(
if (length(attr(z_i, "contrasts")) && !.int && !omit_spatial_intercept) { # factors with ~ 0 or ~ -1
msg <- c("Detected predictors with factor levels in `spatial_varying` with the intercept omitted from the `spatial_varying` formula.",
"You likely want to set `spatial = 'off'` since the constant spatial field (`omega_s`) also represents a spatial intercept.`")
# "As of version 0.3.1, sdmTMB turns off the constant spatial field `omega_s` when `spatial_varying` is specified so that the intercept or factor-level means are fully described by the spatially varying random fields `zeta_s`.")
cli_inform(paste(msg, collapse = " "))
}
.int <- grep("(Intercept)", colnames(z_i))
Expand Down Expand Up @@ -1052,7 +1042,8 @@ sdmTMB <- function(
X_ij_list <- list()
for (i in seq_len(n_m)) X_ij_list[[i]] <- X_ij[[i]]

n_t <- length(unique(data[[time]]))
time_df <- make_time_lu(data[[time]], full_time_vec = union(data[[time]], extra_time))
n_t <- nrow(time_df)

random_walk <- if (!is.null(time_varying)) switch(time_varying_type, rw = 1L, rw0 = 2L, ar1 = 0L) else 0L
tmb_data <- list(
Expand All @@ -1064,7 +1055,7 @@ sdmTMB <- function(
A_st = spde$A_st,
sim_re = if ("sim_re" %in% names(experimental)) as.integer(experimental$sim_re) else rep(0L, 6),
A_spatial_index = spde$sdm_spatial_id - 1L,
year_i = make_year_i(data[[time]]),
year_i = time_df$year_i[match(data[[time]], time_df$time_from_data)],
ar1_fields = ar1_fields,
simulate_t = rep(1L, n_t),
rw_fields = rw_fields,
Expand Down Expand Up @@ -1375,16 +1366,9 @@ sdmTMB <- function(
prof <- c("b_j")
if (delta) prof <- c(prof, "b_j2")

lu <- make_year_lu(data[[time]])
fd <- data[['__fake_data__']]
tmp <- data[!fd,,drop=FALSE]
# strip fake data from A matrix:
if (sum(fd) > 0L) spde <- make_mesh(tmp, spde$xy_cols, mesh = spde$mesh)
tmp[['__fake_data__']] <- tmp[['__weight_sdmTMB__']] <-
tmp[['__sdmTMB_offset__']] <- tmp[['__dcens_upr__']] <- NULL
out_structure <- structure(list(
data = tmp,
offset = offset[!fd],
data = data,
offset = offset,
spde = spde,
formula = original_formula,
split_formula = split_formula,
Expand All @@ -1393,10 +1377,10 @@ sdmTMB <- function(
threshold_function = thresh[[1]]$threshold_func,
epsilon_predictor = epsilon_predictor,
time = time,
time_lu = lu,
time_lu = time_df,
family = family,
smoothers = sm,
response = y_i[!fd,,drop=FALSE],
response = y_i,
tmb_data = tmb_data,
tmb_params = tmb_params,
tmb_map = tmb_map,
Expand Down Expand Up @@ -1593,12 +1577,12 @@ parse_spatial_arg <- function(spatial) {
spatial
}

check_irregalar_time <- function(data, time, spatiotemporal, time_varying) {
check_irregalar_time <- function(data, time, spatiotemporal, time_varying, extra_time) {
if (any(spatiotemporal %in% c("ar1", "rw")) || !is.null(time_varying)) {
if (!is.numeric(data[[time]])) {
cli_abort("Time column should be integer or numeric if using AR(1) or random walk processes.")
}
ti <- sort(unique(data[[time]]))
ti <- sort(union(unique(data[[time]]), extra_time))
if (length(unique(diff(ti))) > 1L) {
missed <- find_missing_time(data[[time]])
msg <- c(
Expand Down
21 changes: 9 additions & 12 deletions R/index.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,6 @@ get_generic <- function(obj, value_name, bias_correct = FALSE, level = 0.95,
obj$fit_obj$control$parallel <- 1L
}

# FIXME parallel setup here?
if (!"fake_nd" %in% names(obj)) { # old sdmTMB versions...
predicted_time <- sort(unique(obj$data[[obj$fit_obj$time]]))
fitted_time <- get_fitted_time(obj$fit_obj)
if (!all(fitted_time %in% predicted_time)) {
cli_abort(paste0("Some of the fitted time elements were not predicted ",
"on with `predict.sdmTMB()`. Either supply all time elements to ",
"predict() or update sdmTMB and re-fit your object."))
}
}

assert_that(!is.null(area))
if (length(area) > 1L) {
n_fakend <- if (!is.null(obj$fake_nd)) nrow(obj$fake_nd) else 0L
Expand Down Expand Up @@ -244,7 +233,15 @@ get_generic <- function(obj, value_name, bias_correct = FALSE, level = 0.95,
d$lwr <- as.numeric(trans(d$trans_est + stats::qnorm((1-level)/2) * d$se))
d$upr <- as.numeric(trans(d$trans_est + stats::qnorm(1-(1-level)/2) * d$se))

d[[time_name]] <- get_fitted_time(obj$fit_obj)
if ("pred_tmb_data" %in% names(obj)) { # standard case
ii <- sort(unique(obj$pred_tmb_data$proj_year))
} else { # fit with do_index = TRUE
ii <- sort(unique(obj$fit_obj$tmb_data$proj_year))
}
d <- d[d$est != 0, ,drop=FALSE] # these were not predicted on
lu <- obj$fit_obj$time_lu
tt <- lu$time_from_data[match(ii, lu$year_i)]
d[[time_name]] <- tt
# d$max_gradient <- max(conv$final_grads)
# d$bad_eig <- conv$bad_eig

Expand Down
41 changes: 9 additions & 32 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ predict.sdmTMB <- function(object, newdata = NULL,
xy_cols <- object$spde$xy_cols
}

if (object$version < numeric_version("0.5.0.9001")) {
cli_abort("This model was fit with an older version of sdmTMB before internal handling of `extra_time` was simplified. Please refit your model before predicting on it (or install version 0.5.0 or 0.5.0.9000).")
}

if (is_present(tmbstan_model)) {
deprecate_stop("0.2.2", "predict.sdmTMB(tmbstan_model)", "predict.sdmTMB(mcmc_samples)")
}
Expand Down Expand Up @@ -315,9 +319,6 @@ predict.sdmTMB <- function(object, newdata = NULL,
if (is.null(newdata)) {
if (is_delta(object) || nsim > 0 || type == "response" || !is.null(mcmc_samples) || se_fit || !is.null(re_form) || !is.null(re_form_iid) || !is.null(offset) || isTRUE(object$family$delta)) {
newdata <- object$data
if (!is.null(object$extra_time)) { # issue #273
newdata <- newdata[!newdata[[object$time]] %in% object$extra_time,]
}
nd_arg_was_null <- TRUE # will be used to carry over the offset
}
}
Expand Down Expand Up @@ -345,7 +346,6 @@ predict.sdmTMB <- function(object, newdata = NULL,
tmb_data <- object$tmb_data
tmb_data$do_predict <- 1L
no_spatial <- as.logical(object$tmb_data$no_spatial)
fake_nd <- NULL

if (!is.null(newdata)) {
if (any(!xy_cols %in% names(newdata)) && isFALSE(pop_pred) && !no_spatial)
Expand Down Expand Up @@ -375,31 +375,15 @@ predict.sdmTMB <- function(object, newdata = NULL,
}

check_time_class(object, newdata)
original_time <- as.integer(get_fitted_time(object))
new_data_time <- as.integer(sort(unique(newdata[[object$time]])))
original_time <- object$time_lu$time_from_data
new_data_time <- unique(newdata[[object$time]])

if (!all(new_data_time %in% original_time))
cli_abort(c("Some new time elements were found in `newdata`. ",
"For now, make sure only time elements from the original dataset are present.",
"If you would like to predict on new time elements,",
"see the `extra_time` argument in `?sdmTMB`.")
)

if (!identical(new_data_time, original_time) & isFALSE(pop_pred)) {
missing_time <- original_time[!original_time %in% new_data_time]
fake_nd_list <- list()
fake_nd <- newdata[1L,,drop=FALSE]
for (.t in seq_along(missing_time)) {
fake_nd[[object$time]] <- missing_time[.t]
fake_nd_list[[.t]] <- fake_nd
}
fake_nd <- do.call("rbind", fake_nd_list)
newdata[["_sdmTMB_fake_nd_"]] <- FALSE
fake_nd[["_sdmTMB_fake_nd_"]] <- TRUE
newdata <- rbind(newdata, fake_nd)
if (!is.null(offset)) offset <- c(offset, rep(0, nrow(fake_nd))) # issue 270
}

# If making population predictions (with standard errors), we don't need
# to worry about space, so fill in dummy values if the user hasn't made any:
fake_spatial_added <- FALSE
Expand Down Expand Up @@ -519,7 +503,8 @@ predict.sdmTMB <- function(object, newdata = NULL,
tmb_data$proj_X_ij <- proj_X_ij
tmb_data$proj_X_rw_ik <- proj_X_rw_ik
tmb_data$proj_RE_indexes <- proj_RE_indexes
tmb_data$proj_year <- make_year_i(nd[[object$time]])
time_lu <- object$time_lu
tmb_data$proj_year <- time_lu$year_i[match(nd[[object$time]], time_lu$time_from_data)] # was make_year_i(nd[[object$time]])
tmb_data$proj_lon <- newdata[[xy_cols[[1]]]]
tmb_data$proj_lat <- newdata[[xy_cols[[2]]]]
tmb_data$calc_se <- as.integer(se_fit)
Expand Down Expand Up @@ -690,9 +675,6 @@ predict.sdmTMB <- function(object, newdata = NULL,
}
}

if (!is.null(fake_nd)) {
out <- out[-seq(nrow(out) - nrow(fake_nd) + 1, nrow(out)), ,drop=FALSE] # issue #273
}
return(out)
}

Expand Down Expand Up @@ -853,7 +835,6 @@ predict.sdmTMB <- function(object, newdata = NULL,
nd[[paste0("zeta_s_", object$spatial_varying[z])]] <- r$zeta_s_A[,z,1]
}
nd$epsilon_st <- r$epsilon_st_A_vec[,1]# DELTA FIXME
nd <- nd[!nd[[object$time]] %in% object$extra_time, , drop = FALSE] # issue 270
obj <- object
}

Expand Down Expand Up @@ -886,14 +867,10 @@ predict.sdmTMB <- function(object, newdata = NULL,
nd[["_sdmTMB_time"]] <- NULL
if (no_spatial) nd[["est_rf"]] <- NULL
if (no_spatial) nd[["est_non_rf"]] <- NULL
if ("_sdmTMB_fake_nd_" %in% names(nd)) {
nd <- nd[!nd[["_sdmTMB_fake_nd_"]],,drop=FALSE]
}
nd[["_sdmTMB_fake_nd_"]] <- NULL
row.names(nd) <- NULL

if (return_tmb_object) {
return(list(data = nd, report = r, obj = obj, fit_obj = object, pred_tmb_data = tmb_data, fake_nd = fake_nd))
return(list(data = nd, report = r, obj = obj, fit_obj = object, pred_tmb_data = tmb_data))
} else {
if (visreg_df) {
# for visreg & related, return consistent objects with lm(), gam() etc.
Expand Down
2 changes: 1 addition & 1 deletion R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ print_time_varying <- function(x, m = 1) {
tv_names <- colnames(model.matrix(x$time_varying, x$data))
mm_tv <- cbind(round(as.numeric(b_rw_t_est), 2L), round(as.numeric(b_rw_t_se), 2L))
colnames(mm_tv) <- c("coef.est", "coef.se")
time_slices <- get_fitted_time(x)
time_slices <- x$time_lu$time_from_data
row.names(mm_tv) <- paste(rep(tv_names, each = length(time_slices)), time_slices, sep = "-")
} else {
mm_tv <- NULL
Expand Down
3 changes: 0 additions & 3 deletions R/tmb-sim.R
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,6 @@ simulate.sdmTMB <- function(object, nsim = 1L, seed = sample.int(1e6, 1L),
}

ret <- do.call(cbind, ret)
if (!is.null(object$extra_time)) {
ret <- ret[seq(1, nrow(object$data)),,drop=FALSE] # drop extra time rows
}
attr(ret, "type") <- type
ret
}
20 changes: 15 additions & 5 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,21 @@ make_year_i <- function(x) {
x - min(x)
}

make_year_lu <- function(x) {
ret <- unique(data.frame(year_i = make_year_i(x), time_from_data = x, stringsAsFactors = FALSE))
ret <- ret[order(ret$year_i),,drop=FALSE]
row.names(ret) <- NULL
ret
# Build the time lookup table: one row per unique time slice, pairing each
# time value with the index returned by `make_year_i()`.
#
# time_vec_from_data: time values that are present in the data.
# full_time_vec: every time slice to include in the table; defaults to the
#   sorted unique values of `time_vec_from_data`.
#
# Returns a data frame ordered by `time_from_data`, with default row names
# and the columns:
#   year_i          index produced by `make_year_i(full_time_vec)`
#   time_from_data  the time value itself
#   extra_time      TRUE when the slice does not occur in `time_vec_from_data`
#
# Stops with an error if any element of `time_vec_from_data` is missing from
# `full_time_vec`.
make_time_lu <- function(time_vec_from_data, full_time_vec = sort(unique(time_vec_from_data))) {
  not_covered <- !time_vec_from_data %in% full_time_vec
  if (any(not_covered)) {
    stop("All time elements not in full time vector.")
  }

  all_slices <- data.frame(
    year_i = make_year_i(full_time_vec),
    time_from_data = full_time_vec,
    stringsAsFactors = FALSE
  )
  # Collapse repeated time values so each slice appears exactly once.
  time_table <- unique(all_slices)

  # Flag slices that exist only in `full_time_vec`, not in the data.
  in_data <- time_table$time_from_data %in% time_vec_from_data
  time_table$extra_time <- !in_data

  by_time <- order(time_table$time_from_data)
  time_table <- time_table[by_time, ]
  row.names(time_table) <- NULL
  time_table
}

check_offset <- function(formula) {
Expand Down
3 changes: 3 additions & 0 deletions man/sdmTMB.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions src/sdmTMB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ Type objective_function<Type>::operator()()

// Total biomass etc.:
vector<Type> total(n_t);
total.setZero();
total.setZero(); // important; 0s are filtered out after as not predicted on
vector<Type> mu_combined(n_p);
mu_combined.setZero();

Expand All @@ -1235,8 +1235,8 @@ Type objective_function<Type>::operator()()
Type t2;
int link_tmp;

if (n_m > 1) { // delta model
for (int i = 0; i < n_p; i++) {
for (int i = 0; i < n_p; i++) {
if (n_m > 1) { // delta model
if (poisson_link_delta) {
// Type R1 = Type(1.) - exp(-exp(proj_eta(i,0)));
// Type R2 = exp(proj_eta(i,0)) / R1 * exp(proj_eta(i,1))
Expand All @@ -1247,9 +1247,7 @@ Type objective_function<Type>::operator()()
mu_combined(i) = t1 * t2;
}
total(proj_year(i)) += mu_combined(i) * area_i(i);
}
} else { // non-delta model
for (int i = 0; i < n_p; i++) {
} else { // non-delta model
mu_combined(i) = InverseLink(proj_eta(i,0), link(0));
total(proj_year(i)) += mu_combined(i) * area_i(i);
}
Expand Down
Loading
Loading