Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ts_auto_arima() to use "auto_arima" for model spec when .tune = FALSE #489

Closed
spsanderson opened this issue Oct 25, 2023 · 0 comments
Assignees
Labels
enhancement New feature or request

Comments

@spsanderson
Copy link
Owner

spsanderson commented Oct 25, 2023

Update the model spec so that it uses the "auto_arima" engine when the .tune parameter is FALSE.

Function:

#' Boilerplate ARIMA workflow
#'
#' Builds a recipe, an ARIMA model specification and a workflow, optionally
#' tunes the specification over a Latin hypercube grid using time series
#' cross validation, fits the workflow on the training portion of the split
#' and calibrates/plots the fitted workflow.
#'
#' @param .data The data; it is coerced with `dplyr::as_tibble()` and used
#'   for the recipe and for calibration.
#' @param .date_col The (unquoted) date column. It is passed as `date_var`
#'   to `timetk::time_series_cv()` when tuning.
#' @param .value_col The (unquoted) value column. Inside this function it is
#'   only checked for having been supplied.
#' @param .formula The formula handed to `recipes::recipe()`.
#' @param .rsamp_obj An object inheriting from class `rsplit`.
#' @param .prefix Prefix used to build the recipe syntax name
#'   (`paste0(.prefix, "_recipe")`).
#' @param .tune When `TRUE` a fully tunable `modeltime::arima_reg()` spec
#'   with the `"arima"` engine is tuned via `tune::tune_grid()`; otherwise a
#'   default `modeltime::arima_reg()` spec with the `"auto_arima"` engine is
#'   fitted directly.
#' @param .grid_size Size of the Latin hypercube tuning grid.
#' @param .num_cores Passed to `modeltime::parallel_start()` when tuning.
#' @param .cv_assess Passed as `assess` to `timetk::time_series_cv()`.
#' @param .cv_skip Passed as `skip` to `timetk::time_series_cv()`.
#' @param .cv_slice_limit Passed as `slice_limit` to
#'   `timetk::time_series_cv()`.
#' @param .best_metric Metric name used by `tune::show_best()` to select the
#'   best parameter set.
#' @param .bootstrap_final Not used in any computation here; it is only
#'   recorded as an attribute on the result.
#'
#' @return Invisibly, a list with elements `recipe_info`, `model_info` and
#'   `model_calibration`, plus `tuned_info` when `.tune` is `TRUE`. The
#'   function arguments, the mode, the parsnip engine and the function
#'   family are attached as attributes.
ts_auto_arima <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                          .prefix = "ts_arima", .tune = TRUE, .grid_size = 10,
                          .num_cores = 1, .cv_assess = 12, .cv_skip = 3,
                          .cv_slice_limit = 6, .best_metric = "rmse",
                          .bootstrap_final = FALSE) {

  # Tidyeval ----
  date_col_var_expr <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)
  # Cross Validation
  cv_assess <- as.numeric(.cv_assess)
  cv_skip <- as.numeric(.cv_skip)
  cv_slice <- as.numeric(.cv_slice_limit)
  # Tuning Grid
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)
  # Data and splits
  splits <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)

  # Checks ----
  if (rlang::quo_is_missing(date_col_var_expr)) {
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (rlang::quo_is_missing(value_col_var_expr)) {
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = splits, what = "rsplit")) {
    rlang::abort(
      message = "'.rsamp_obj' must be have class rsplit, use the rsample package.",
      use_cli_format = TRUE
    )
  }

  # Recipe ----
  # Get the initial recipe call
  recipe_call <- get_recipe_call(match.call())

  rec_syntax <- paste0(.prefix, "_recipe") %>%
    assign_value(!!recipe_call)

  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl)

  # Tune/Spec ----
  # Tuning needs explicit tunable arguments on the "arima" engine; without
  # tuning the "auto_arima" engine is used on a default spec.
  if (.tune) {
    model_spec <- modeltime::arima_reg(
      seasonal_period            = tune::tune()
      , non_seasonal_ar          = tune::tune()
      , non_seasonal_differences = tune::tune()
      , non_seasonal_ma          = tune::tune()
      , seasonal_ar              = tune::tune()
      , seasonal_differences     = tune::tune()
      , seasonal_ma              = tune::tune()
    ) %>%
      parsnip::set_mode(mode = "regression") %>%
      parsnip::set_engine("arima")
  } else {
    model_spec <- modeltime::arima_reg() %>%
      parsnip::set_mode(mode = "regression") %>%
      parsnip::set_engine("auto_arima")
  }

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Tuning Grid ----
  if (.tune) {

    # Make sure the parallel backend is stopped even if tuning or fitting
    # fails part-way; the flag keeps the normal path from stopping it twice.
    backend_active <- FALSE
    on.exit(
      if (backend_active) modeltime::parallel_stop(),
      add = TRUE
    )

    # Start parallel backend
    modeltime::parallel_start(num_cores)
    backend_active <- TRUE

    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )

    # Make TS CV ----
    tscv <- timetk::time_series_cv(
      data        = rsample::training(splits),
      date_var    = {{date_col_var_expr}},
      cumulative  = TRUE,
      assess      = cv_assess,
      skip        = cv_skip,
      slice_limit = cv_slice
    )

    # Tune the workflow
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = tscv,
        grid      = tuning_grid_spec,
        metrics   = modeltime::default_forecast_accuracy_metric_set()
      )

    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE)

    # Make final workflow, reusing the best result set selected above
    wflw_fit <- wflw %>%
      tune::finalize_workflow(best_result_set) %>%
      parsnip::fit(rsample::training(splits))

    # Stop parallel backend
    modeltime::parallel_stop()
    backend_active <- FALSE

  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj  = splits,
    .data        = data_tbl,
    .interactive = TRUE,
    .print_info  = FALSE
  )

  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call   = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj       = rec_obj
    ),
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = if (.tune) "tuned" else "not_tuned"
    ),
    model_calibration = list(
      plot = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy = cap$model_accuracy
    )
  )

  if (.tune) {
    output$tuned_info <- list(
      tuning_grid      = tuning_grid_spec,
      tscv             = tscv,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  # Add attributes
  attr(output, ".tune") <- .tune
  attr(output, ".grid_size") <- .grid_size
  attr(output, ".cv_assess") <- .cv_assess
  attr(output, ".cv_skip") <- .cv_skip
  attr(output, ".cv_slice_limit") <- .cv_slice_limit
  attr(output, ".best_metric") <- .best_metric
  attr(output, ".bootstrap_final") <- .bootstrap_final
  attr(output, ".mode") <- "regression"
  attr(output, ".parsnip_engine") <- if (.tune) "arima" else "auto_arima"
  attr(output, ".function_family") <- "boilerplate"

  invisible(output)
}

Examples:

library(dplyr)
library(timetk)
library(modeltime)
library(healthyR.ts)

# Convert the AirPassengers series to a tibble and drop the index column
data <- select(ts_to_tbl(AirPassengers), -index)

# Build the rsplit object that ts_auto_arima() expects
splits <- time_series_split(
  data,
  date_col,
  assess = 12,
  skip = 3,
  cumulative = TRUE
)

# Run with tuning switched on
tsa <- ts_auto_arima(
  .data = data,
  .num_cores = 2,
  .date_col = date_col,
  .value_col = value,
  .rsamp_obj = splits,
  .formula = value ~ .,
  .grid_size = 5,
  .cv_slice_limit = 2,
  .tune = TRUE
)
# Inspect the model specification stored on the result
tsa[["model_info"]][["model_spec"]]
ARIMA Regression Model Specification (regression)

Main Arguments:
  seasonal_period = tune::tune()
  non_seasonal_ar = tune::tune()
  non_seasonal_differences = tune::tune()
  non_seasonal_ma = tune::tune()
  seasonal_ar = tune::tune()
  seasonal_differences = tune::tune()
  seasonal_ma = tune::tune()

Computational engine: arima 

# Same call as the tuned run, but with tuning switched off
tsa <- ts_auto_arima(
  .data = data,
  .num_cores = 2,
  .date_col = date_col,
  .value_col = value,
  .rsamp_obj = splits,
  .formula = value ~ .,
  .grid_size = 5,
  .cv_slice_limit = 2,
  .tune = FALSE
)
frequency = 12 observations per 1 year

ARIMA Regression Model Specification (regression)

Computational engine: auto_arima 
@spsanderson spsanderson added the enhancement New feature or request label Oct 25, 2023
@spsanderson spsanderson added this to the healthyR.ts 0.3 milestone Oct 25, 2023
@spsanderson spsanderson self-assigned this Oct 25, 2023
@spsanderson spsanderson changed the title Update ts_auto_arimt() to use "auto_arima" for model spec when .tune = FALSE Update ts_auto_arima() to use "auto_arima" for model spec when .tune = FALSE Oct 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Development

No branches or pull requests

1 participant