Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
fable/R/lagwalk.R
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
447 lines (413 sloc)
11.8 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Estimate a lag walk model (naive, seasonal naive, or random walk with drift).
#
# Arguments:
#   .data    Data holding a single measured variable (the response). The
#            response column is located with measured_vars(), the time origin
#            with index_var() and interval().
#   specials Parsed specials. `specials$lag[[1]]` is the lag order;
#            `specials$drift[[1]]`, when present, holds the `drift` flag and
#            an optional `fixed` drift coefficient.
#   ...      Not used.
#
# Returns a list of class "RW" with the drift (`b`), its standard error
# (`b.se`), the lag, the residual variance (`sigma2`), fitted values,
# residuals, the time origin/interval, and `future`: the last `lag` observed
# values (left-padded with NA when the series is shorter than the lag).
#
# Aborts when more than one response is supplied or when every observation is
# missing.
train_lagwalk <- function(.data, specials, ...) {
  if (length(measured_vars(.data)) > 1) {
    abort("Only univariate responses are supported by lagwalks.")
  }

  y <- unclass(.data)[[measured_vars(.data)]]
  n <- length(y)

  if (all(is.na(y))) {
    abort("All observations are missing, a model cannot be estimated without data.")
  }

  drift <- specials$drift[[1]][[1]] %||% FALSE
  fixed <- specials$drift[[1]][[2]]
  lag <- specials$lag[[1]]

  # Fill each missing observation with the value one lag earlier so that later
  # fitted values are not lost to a single gap. Positions are visited in
  # increasing order, so a run of gaps chains back to the last usable value.
  # Gaps inside the first lag window have nothing to borrow from.
  fits <- y
  y_na <- which(is.na(y))
  for (i in y_na[y_na > lag]) {
    fits[i] <- fits[i - lag]
  }

  # The fitted value at time t is the (gap-filled) observation at time t - lag.
  fitted <- c(rep(NA, min(lag, n)), utils::head(fits, -lag))

  # Initial model estimation or re-estimation of RW model (with drift).
  if (drift) {
    if (!is.null(fixed)) {
      # Drift supplied by the caller; its standard error is unknown here and is
      # filled in by refit.RW.
      b <- fixed
      b.se <- dbl()
    } else {
      # The drift is the mean of the lagged differences, estimated with an
      # intercept-only regression so its standard error is available.
      fit <- summary(stats::lm(y - fitted ~ 1, na.action = stats::na.exclude))
      b <- fit$coefficients[1, 1]
      b.se <- fit$coefficients[1, 2]
    }
    fitted <- fitted + b
  } else {
    # No drift model.
    b <- b.se <- dbl()
  }

  res <- y - fitted
  sigma <- stats::sd(res, na.rm = TRUE)

  structure(
    list(
      b = b,
      b.se = b.se,
      lag = lag,
      sigma2 = sigma^2,
      .fitted = fitted,
      .resid = res,
      time = list(
        start = unclass(.data)[[index_var(.data)]][[1]],
        interval = interval(.data)
      ),
      future = y[c(rep(NA, max(0, lag - n)), seq_len(min(n, lag)) + n - min(n, lag))]
    ),
    class = "RW"
  )
}
#' Random walk models
#'
#' \code{RW()} returns a random walk model, which is equivalent to an ARIMA(0,1,0)
#' model with an optional drift coefficient included using \code{drift()}.
#' \code{NAIVE()} is an alias of \code{RW()}, provided for convenience.
#' \code{SNAIVE()} returns a seasonal naive model, which is equivalent to an
#' ARIMA(0,0,0)(0,1,0)m model where m is the seasonal period.
#'
#' The random walk with drift model is \deqn{Y_t=c + Y_{t-1} + Z_t}{Y[t]=c +
#' Y[t-1] + Z[t]} where \eqn{Z_t}{Z[t]} is a normal iid error. Forecasts are
#' given by \deqn{Y_n(h)=ch+Y_n}{Y[n+h]=ch+Y[n]}. If there is no drift (the
#' default), the drift parameter c=0. Forecast standard errors allow for
#' uncertainty in estimating the drift parameter (unlike the corresponding
#' forecasts obtained by fitting an ARIMA model directly).
#'
#' The seasonal naive model is \deqn{Y_t= Y_{t-m} + Z_t}{Y[t]=Y[t-m] + Z[t]}
#' where \eqn{Z_t}{Z[t]} is a normal iid error.
#'
#' @aliases report.RW
#'
#' @param formula Model specification (see "Specials" section).
#' @param ... Not used.
#'
#' @section Specials:
#'
#' \subsection{lag}{
#' The `lag` special is used to specify the lag order for the random walk process.
#' If left out, this special will automatically be included.
#'
#' \preformatted{
#' lag(lag = NULL)
#' }
#'
#' \tabular{ll}{
#'   `lag`        \tab The lag order for the random walk process. If `lag = m`, forecasts will return the observation from `m` time periods ago. This can also be provided as text indicating the duration of the lag window (for example, annual seasonal lags would be "1 year"). Non-integer lags are rounded to the nearest integer with a warning.
#' }
#' }
#'
#' \subsection{drift}{
#' The `drift` special can be used to include a drift/trend component into the model. By default, drift is not included unless `drift()` is included in the formula.
#'
#' \preformatted{
#' drift(drift = TRUE, fixed = NULL)
#' }
#'
#' \tabular{ll}{
#'   `drift`      \tab If `drift = TRUE`, a drift term will be included in the model.\cr
#'   `fixed`      \tab If not `NULL`, this value is used as the drift coefficient instead of estimating it from the data.
#' }
#' }
#'
#' @return A model specification.
#'
#' @seealso
#' [Forecasting: Principles and Practices, Some simple forecasting methods (section 3.2)](https://otexts.com/fpp3/simple-methods.html)
#'
#' @examples
#' library(tsibbledata)
#' aus_production %>%
#'   model(rw = RW(Beer ~ drift()))
#' @export
RW <- function(formula, ...) {
  rw_model <- new_model_class("RW",
    train = train_lagwalk,
    specials = new_specials(
      lag = function(lag = NULL) {
        if (is.null(lag)) {
          lag <- 1
        }
        # Resolve the lag against the data first: it may be supplied as text
        # (e.g. "1 year"), which cannot be tested or rounded as a number until
        # it has been converted. This is the same order SNAIVE() uses.
        lag <- get_frequencies(lag, self$data, .auto = "smallest")
        if (!rlang::is_integerish(lag)) {
          warn("Non-integer lag orders for random walk models are not supported. Rounding to the nearest integer.")
          lag <- round(lag)
        }
        lag
      },
      drift = function(drift = TRUE, fixed = NULL) {
        list(drift = drift, fixed = fixed)
      },
      xreg = no_xreg,
      .required_specials = c("lag")
    ),
    check = all_tsbl_checks
  )
  new_model_definition(rw_model, !!enquo(formula), ...)
}
# NAIVE() is bound to the very same function object as RW(); it shares RW's
# documentation page and accepts the same formula and specials.
#' @rdname RW
#'
#' @examples
#'
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value))
#' @export
NAIVE <- RW
#' @rdname RW
#'
#' @examples
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year")))
#' @export
SNAIVE <- function(formula, ...) {
  snaive_model <- new_model_class("RW",
    train = train_lagwalk,
    specials = new_specials(
      lag = function(lag = NULL) {
        lag <- get_frequencies(lag, self$data, .auto = "smallest")
        # Round before testing for a non-seasonal lag; otherwise a value such
        # as 1.2 passes the check below and is only afterwards rounded to 1.
        if (!rlang::is_integerish(lag)) {
          warn("Non-integer lag orders for random walk models are not supported. Rounding to the nearest integer.")
          lag <- round(lag)
        }
        if (lag == 1) {
          abort("Non-seasonal model specification provided, use RW() or provide a different lag specification.")
        }
        lag
      },
      drift = function(drift = TRUE, fixed = NULL) {
        list(drift = drift, fixed = fixed)
      },
      xreg = no_xreg,
      .required_specials = c("lag")
    ),
    check = all_tsbl_checks
  )
  new_model_definition(snaive_model, !!enquo(formula), ...)
}
#' @inherit forecast.ARIMA
#' @inheritParams forecast.ETS
#' @importFrom stats qnorm time
#' @importFrom utils tail
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   forecast()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   forecast()
#' @export
forecast.RW <- function(object, new_data, specials = NULL, simulate = FALSE, bootstrap = FALSE, times = 5000, ...) {
  h <- NROW(new_data)
  lag <- object$lag

  # Number of lag-length cycles reached at each horizon: 1 for the first `lag`
  # steps ahead, 2 for the next `lag`, and so on. Built from seq_len() so that
  # an empty `new_data` gives empty forecasts instead of an indexing error.
  steps <- (seq_len(h) - 1) %/% lag + 1

  b <- object$b
  b.se <- object$b.se
  if (is_empty(b)) {
    # No drift in the model.
    b <- b.se <- 0
  }
  if (is_empty(b.se)) {
    # Drift was supplied through `fixed`, so no standard error was estimated;
    # treat the coefficient as known rather than letting `se` collapse to
    # length zero.
    b.se <- 0
  }

  # Produce forecasts
  if (simulate || bootstrap) { # Compute prediction intervals using simulations
    sim <- map(seq_len(times), function(x) {
      generate(object, new_data, bootstrap = bootstrap)[[".sim"]]
    }) %>%
      transpose() %>%
      map(as.numeric)
    distributional::dist_sample(sim)
  } else {
    # Point forecasts repeat the last observed lag window, plus accumulated drift.
    fc <- rep_len(object$future, h) + steps * b
    res <- residuals(object)
    # One degree of freedom is removed when a non-zero drift is present.
    mse <- sum(res^2, na.rm = TRUE) / (sum(!is.na(res)) - (b != 0))
    if (is.nan(mse)) mse <- NA
    # Adjust prediction intervals to allow for drift coefficient standard error
    se <- sqrt(mse * steps + (steps * b.se)^2)
    distributional::dist_normal(fc, se)
  }
}
#' @inherit generate.ETS
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   generate()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   generate()
#' @export
generate.RW <- function(x, new_data, bootstrap = FALSE, ...) {
  if (!is_regular(new_data)) {
    abort("Simulation new_data must be regularly spaced")
  }

  lag <- x$lag
  if (is_empty(x$b)) {
    # No drift in the model.
    b <- 0
  } else if (is_empty(x$b.se)) {
    # Drift supplied through `fixed`: there is no standard error to sample
    # with, so use the coefficient as given instead of drawing an NA.
    b <- x$b
  } else {
    # Draw a drift for this sample path to reflect estimation uncertainty.
    b <- stats::rnorm(1, mean = x$b, sd = x$b.se)
  }

  # Locate the lag window that seeds the simulation: fitted values followed by
  # the stored `future` values, positioned on the model's time grid.
  fits <- c(x$.fitted, x$future)
  start_idx <- min(new_data[[index_var(new_data)]])
  start_pos <- match(start_idx, seq(x$time$start, by = default_time_units(x$time$interval), length.out = length(fits)))
  future <- fits[start_pos + seq_len(lag) - 1]

  if (any(is.na(future))) {
    abort("The first lag window for simulation must be within the model's training set.")
  }

  if (!(".innov" %in% names(new_data))) {
    if (bootstrap) {
      new_data$.innov <- sample(stats::na.omit(residuals(x) - mean(residuals(x), na.rm = TRUE)),
        NROW(new_data),
        replace = TRUE
      )
    } else {
      new_data$.innov <- stats::rnorm(NROW(new_data), sd = sqrt(x$sigma2))
    }
  }

  sim_rw <- function(e) {
    # Cumulate errors (plus drift) separately within each position of the lag
    # cycle, then add them to the seed values for that position.
    dx <- e + b
    lag_grp <- rep_len(seq_len(lag), length(dx))
    cumulative_e <- unsplit(lapply(split(dx, lag_grp), cumsum), lag_grp)
    # Size the seed by the number of simulated steps (not by the number of
    # groups), so no partial recycling occurs when the horizon is not a
    # multiple of the lag.
    rep_len(future, length(lag_grp)) + cumulative_e
  }

  transmute(group_by_key(new_data), ".sim" := sim_rw(!!sym(".innov")))
}
#' @inherit fitted.ARIMA
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   fitted()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   fitted()
#' @export
fitted.RW <- function(object, ...) {
  # Fitted values are stored on the model object under ".fitted".
  fitted_values <- object[[".fitted"]]
  fitted_values
}
#' @inherit residuals.ARIMA
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   residuals()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   residuals()
#' @export
residuals.RW <- function(object, ...) {
  # Residuals are stored on the model object under ".resid".
  resid_values <- object[[".resid"]]
  resid_values
}
#' Glance a lag walk model
#'
#' Construct a single row summary of the lag walk model.
#' Contains the variance of residuals (`sigma2`).
#'
#' @inheritParams generics::glance
#'
#' @return A one row tibble summarising the model's fit.
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   glance()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   glance()
#' @export
glance.RW <- function(x, ...) {
  # The only summary statistic reported is the stored residual variance.
  resid_var <- x[["sigma2"]]
  tibble(sigma2 = resid_var)
}
#' @inherit tidy.ARIMA
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   tidy()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   tidy()
#' @export
tidy.RW <- function(x, ...) {
  # The drift is the only coefficient; without it every column is empty.
  drift <- !is_empty(x$b)
  t_stat <- x$b / x$b.se
  # Degrees of freedom: residual count less the lag and, if present, the drift.
  resid_df <- length(x$.resid) - x$lag - drift
  tibble(
    term = if (drift) "b" else chr(),
    estimate = x$b,
    std.error = x$b.se,
    statistic = t_stat,
    p.value = 2 * stats::pt(abs(t_stat), resid_df, lower.tail = FALSE)
  )
}
#' @export
report.RW <- function(object, ...) {
  # Prints a blank line, the drift and its standard error (only when the model
  # has a drift term), then the residual variance; values rounded to 4 places.
  cat("\n")
  drift <- object[["b"]]
  if (!is_empty(drift)) {
    drift_line <- paste0(
      "Drift: ", round(drift, 4),
      " (se: ", round(object[["b.se"]], 4), ")\n"
    )
    cat(drift_line)
  }
  sigma2_line <- paste("sigma^2:", round(object[["sigma2"]], 4), "\n")
  cat(sigma2_line)
}
#' @importFrom stats coef
#' @export
model_sum.RW <- function(x) {
  # Short label for the model: any lag other than 1 is "SNAIVE"; a lag of 1 is
  # "RW" when a drift is present and "NAIVE" otherwise. Models with a drift
  # additionally get the "w/ drift" suffix.
  drift <- !is_empty(x[["b"]])
  unit_lag <- x[["lag"]] == 1

  method <- if (!unit_lag) {
    "SNAIVE"
  } else if (drift) {
    "RW"
  } else {
    "NAIVE"
  }

  if (drift) paste(method, "w/ drift") else method
}
#' Refit a lag walk model
#'
#' Applies a fitted random walk model to a new dataset.
#'
#' The models `NAIVE` and `SNAIVE` have no specific model parameters. Using
#' `refit` for one of these models will provide the same estimation results as
#' fitting the model directly with `fabletools::model(NAIVE(...))` (or
#' `fabletools::model(SNAIVE(...))`).
#'
#' @inheritParams refit.ARIMA
#' @param reestimate If `TRUE`, the lag walk model will be re-estimated
#' to suit the new data.
#'
#' @examples
#' lung_deaths_male <- as_tsibble(mdeaths)
#' lung_deaths_female <- as_tsibble(fdeaths)
#'
#' fit <- lung_deaths_male %>%
#'   model(RW(value ~ drift()))
#'
#' report(fit)
#'
#' fit %>%
#'   refit(lung_deaths_female) %>%
#'   report()
#' @export
refit.RW <- function(object, new_data, specials = NULL, reestimate = FALSE, ...) {
  # The lag order always carries over from the original fit.
  specials$lag <- object$lag

  if (reestimate) {
    # Estimate everything else afresh on the new data.
    train_lagwalk(new_data, specials, ...)
  } else {
    # Keep the original drift by passing it to the trainer as a fixed value.
    has_drift <- !rlang::is_empty(object$b)
    if (has_drift) {
      specials$drift[[1]][[2]] <- object$b
    }
    updated <- train_lagwalk(new_data, specials, ...)
    # The trainer leaves the standard error of a fixed drift empty; restore the
    # original one (a number, or an empty double when there is no drift).
    updated$b.se <- object$b.se
    updated
  }
}