-
Notifications
You must be signed in to change notification settings - Fork 819
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Base version of TBATS * Add both BATS and TBATS * import both bats and tbats * fix an issue * added unit tests for bats and tbats probabilistic * add accuracy unit tests * Remove useless lines * Improve doc * Small fix * Add BATS/TBATS to pmdarima flavour * better frequency support * Update darts/models/forecasting/tbats.py Co-authored-by: Dennis Bader <dennis.bader@gmx.ch> * Update darts/models/forecasting/tbats.py Co-authored-by: Dennis Bader <dennis.bader@gmx.ch> * some PR comments * address PR comments Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
- Loading branch information
1 parent
d48f265
commit 16bece0
Showing
5 changed files
with
287 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
""" | ||
BATS and TBATS
--------------

(T)BATS models [1]_ stand for

* (Trigonometric)
* Box-Cox
* ARMA errors
* Trend
* Seasonal components

They are appropriate to model "complex
seasonal time series such as those with multiple
seasonal periods, high frequency seasonality,
non-integer seasonality and dual-calendar effects" [1]_.

References
----------
.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf
""" | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
from scipy.special import inv_boxcox | ||
from tbats import BATS as tbats_BATS | ||
from tbats import TBATS as tbats_TBATS | ||
|
||
from darts.logging import get_logger | ||
from darts.models.forecasting.forecasting_model import ForecastingModel | ||
from darts.timeseries import TimeSeries | ||
|
||
# Module-level logger, requested from the darts logging helper under this module's name.
logger = get_logger(__name__)
|
||
|
||
def _seasonality_from_freq(series: "TimeSeries"):
    """
    Infer a naive seasonality based on the frequency

    Only ``series.has_range_index`` and ``series.freq_str`` are inspected.

    Parameters
    ----------
    series
        The series whose frequency string is mapped to a seasonal period.

    Returns
    -------
    Optional[List[int]]
        A single-element list holding the inferred period (e.g. ``[12]`` for a
        monthly frequency), or ``None`` when the series has a range index or its
        frequency string is not recognised.
    """

    # An integer-indexed series carries no calendar frequency to reason about.
    if series.has_range_index:
        return None

    freq = series.freq_str

    if freq in ("B", "C"):
        return [5]  # business days per week
    if freq == "D":
        return [7]  # days per week
    # Weekly frequencies are commonly reported with an anchor ("W-SUN", "W-MON", ...),
    # so a bare equality test on "W" would miss them.
    if freq == "W" or freq.startswith("W-"):
        return [52]  # weeks per year
    # Prefix matching also covers anchored / start / end variants ("MS", "ME", "BMS", ...).
    if freq.startswith(("M", "BM", "BS", "CBM", "SM")):
        return [12]  # month
    if freq.startswith(("Q", "BQ", "REQ")):
        return [4]  # quarter
    # Both the historical upper-case aliases and the lower-case ones are accepted.
    if freq in ("H", "BH", "CBH", "h", "bh", "cbh"):
        return [24]  # hour
    if freq in ("T", "min"):
        return [60]  # minute
    if freq in ("S", "s"):
        return [60]  # second

    return None
|
||
|
||
def _compute_samples(model, predictions, n_samples):
    """
    This function is drawn from Model._calculate_confidence_intervals() in tbats.
    We have to implement our own version here in order to compute the samples before
    the inverse boxcox transform.

    Parameters
    ----------
    model
        A fitted tbats model. The attributes ``matrix``, ``resid_boxcox``, ``y``,
        ``params.box_cox_lambda`` and the method ``_boxcox`` are used.
    predictions
        The point forecasts (assumed to be a 1-D array with one value per step).
    n_samples
        Number of samples to draw for every forecast step.

    Returns
    -------
    np.ndarray
        ``predictions`` with an extra axis inserted at position 1 when ``n_samples == 1``;
        otherwise an array of shape ``(len(predictions), 1, n_samples)`` holding the
        drawn samples.
    """

    # In the deterministic case we return the analytic mean
    if n_samples == 1:
        return np.expand_dims(predictions, axis=1)

    F = model.matrix.make_F_matrix()
    g = model.matrix.make_g_vector()
    w = model.matrix.make_w_vector()

    horizon = len(predictions)

    # c[0] stays at 1; every later entry is w @ F^(step-1) @ g, with the matrix
    # power accumulated incrementally in f_running.
    c = np.ones(horizon)
    f_running = np.identity(F.shape[1])
    for step in range(1, horizon):
        c[step] = w @ f_running @ g
        f_running = f_running @ F
    variance_multiplier = np.cumsum(c * c)

    # Mean squared residual (in the transformed space), scaled per step.
    base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len(
        model.y
    )
    variance_boxcox = base_variance_boxcox * variance_multiplier
    std_boxcox = np.sqrt(variance_boxcox)

    # get the samples before inverse boxcoxing
    # (drawn from NumPy's global RNG as (n_samples, horizon), then transposed)
    samples = np.random.normal(
        loc=model._boxcox(predictions),
        scale=std_boxcox,
        size=(n_samples, horizon),
    ).T
    samples = np.expand_dims(samples, axis=1)

    # apply inverse boxcox if needed
    boxcox_lambda = model.params.box_cox_lambda
    if boxcox_lambda is not None:
        samples = inv_boxcox(samples, boxcox_lambda)

    return samples
|
||
|
||
class _BaseBatsTbatsModel(ForecastingModel, ABC):
    def __init__(
        self,
        use_box_cox: Optional[bool] = None,
        box_cox_bounds: Tuple = (0, 1),
        use_trend: Optional[bool] = None,
        use_damped_trend: Optional[bool] = None,
        seasonal_periods: Optional[Union[str, List]] = "freq",
        use_arma_errors: Optional[bool] = True,
        show_warnings: bool = False,
        n_jobs: Optional[int] = None,
        multiprocessing_start_method: Optional[str] = "spawn",
        random_state: int = 0,
    ):

        """
        This is a wrapper around
        `tbats
        <https://github.com/intive-DataScience/tbats>`_.

        This implementation also provides naive frequency inference (when "freq"
        is provided for ``seasonal_periods``),
        as well as Darts-compatible sampling of the resulting normal distribution.

        For convenience, the tbats documentation of the parameters is reported here.

        Parameters
        ----------
        use_box_cox
            If Box-Cox transformation of original series should be applied.
            When ``None`` both cases shall be considered and better is selected by AIC.
        box_cox_bounds
            Minimal and maximal Box-Cox parameter values.
        use_trend
            Indicates whether to include a trend or not.
            When ``None``, both cases shall be considered and the better one is selected by AIC.
        use_damped_trend
            Indicates whether to include a damping parameter in the trend or not.
            Applies only when trend is used.
            When ``None``, both cases shall be considered and the better one is selected by AIC.
        seasonal_periods
            Length of each of the periods (amount of observations in each period).
            TBATS accepts int and float values here.
            BATS accepts only int values.
            When ``None`` or empty array, non-seasonal model shall be fitted.
            If set to ``"freq"``, a single "naive" seasonality
            based on the series frequency will be used (e.g. [12] for monthly series).
            In this latter case, the seasonality will be recomputed every time the model is fit.
        use_arma_errors
            When True BATS will try to improve the model by modelling residuals with ARMA.
            Best model will be selected by AIC.
            If ``False``, ARMA residuals modeling will not be considered.
        show_warnings
            If warnings should be shown or not.
        n_jobs
            How many jobs to run in parallel when fitting BATS model.
            When not provided BATS shall try to utilize all available cpu cores.
        multiprocessing_start_method
            How threads should be started.
            See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
        random_state
            Sets the underlying random seed at model initialization time.
        """
        super().__init__()

        # Keyword arguments forwarded verbatim to the tbats estimator in `_create_model()`.
        # "seasonal_periods" is overwritten in `fit()` when frequency inference is enabled.
        self.kwargs = {
            "use_box_cox": use_box_cox,
            "box_cox_bounds": box_cox_bounds,
            "use_trend": use_trend,
            "use_damped_trend": use_damped_trend,
            "seasonal_periods": seasonal_periods,
            "use_arma_errors": use_arma_errors,
            "show_warnings": show_warnings,
            "n_jobs": n_jobs,
            "multiprocessing_start_method": multiprocessing_start_method,
        }

        self.seasonal_periods = seasonal_periods
        # The isinstance guard keeps this a plain bool even when an array-like is passed:
        # comparing e.g. a NumPy array with "freq" may yield an element-wise result whose
        # truth value is ambiguous when tested in `fit()`.
        self.infer_seasonal_periods = (
            isinstance(seasonal_periods, str) and seasonal_periods == "freq"
        )
        self.model = None
        # Seeds NumPy's global random number generator, which the sampling helper draws from.
        np.random.seed(random_state)

    def __str__(self):
        return "(T)BATS"

    @abstractmethod
    def _create_model(self):
        """Return a new, unfitted tbats estimator built from ``self.kwargs``."""
        pass

    def fit(self, series: TimeSeries):
        """
        Fit the underlying tbats estimator on ``series``.

        When ``seasonal_periods="freq"`` was given at construction, the seasonality is
        re-inferred from the series frequency on every call.

        Parameters
        ----------
        series
            The training series.

        Returns
        -------
        self
            The fitted model.
        """
        super().fit(series)
        series = self.training_series

        if self.infer_seasonal_periods:
            seasonality = _seasonality_from_freq(series)
            self.kwargs["seasonal_periods"] = seasonality
            self.seasonal_periods = seasonality

        model = self._create_model()
        fitted_model = model.fit(series.values())
        self.model = fitted_model

        return self

    def predict(self, n, num_samples=1):
        """
        Forecast ``n`` steps after the end of the training series.

        Parameters
        ----------
        n
            Number of steps to forecast.
        num_samples
            Number of samples to draw per step; ``1`` yields the analytic mean forecast.

        Returns
        -------
        The forecast built by ``self._build_forecast_series`` from the samples.
        """
        super().predict(n, num_samples)

        yhat = self.model.forecast(steps=n)
        samples = _compute_samples(self.model, yhat, num_samples)

        return self._build_forecast_series(samples)

    def _is_probabilistic(self) -> bool:
        return True

    @property
    def min_train_series_length(self) -> int:
        """
        Minimum number of training points: ``2 * seasonal_periods`` when that attribute
        is an ``int`` greater than 1, and 3 otherwise.
        """
        # NOTE(review): `seasonal_periods` is documented as a list (or "freq"/None), and
        # `fit()` stores a list when inferring it, so the `int` branch below is only
        # reached if a caller passes a bare int. Confirm whether list-valued periods
        # should also raise this minimum.
        if isinstance(self.seasonal_periods, int) and self.seasonal_periods > 1:
            return 2 * self.seasonal_periods
        return 3
|
||
|
||
class TBATS(_BaseBatsTbatsModel):
    """TBATS model: the shared (T)BATS wrapper backed by the ``tbats.TBATS`` estimator."""

    def _create_model(self):
        # A new estimator is built from the stored keyword arguments on every call.
        return tbats_TBATS(**self.kwargs)
|
||
|
||
class BATS(_BaseBatsTbatsModel):
    """BATS model: the shared (T)BATS wrapper backed by the ``tbats.BATS`` estimator."""

    def _create_model(self):
        # A new estimator is built from the stored keyword arguments on every call.
        return tbats_BATS(**self.kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
pmdarima>=1.8.0 | ||
tbats>=1.1.0 |