Skip to content

Commit

Permalink
Feat/tbats (#816)
Browse files Browse the repository at this point in the history
* Base version of TBATS

* Add both BATS and TBATS

* import both bats and tbats

* fix an issue

* added unit tests for bats and tbats probabilistic

* add accuracy unit tests

* Remove useless lines

* Improve doc

* Small fix

* Add BATS/TBATS to pmdarima flavour

* better frequency support

* Update darts/models/forecasting/tbats.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* Update darts/models/forecasting/tbats.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* some PR comments

* address PR comments

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
  • Loading branch information
hrzn and dennisbader committed Mar 9, 2022
1 parent d48f265 commit 16bece0
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 6 deletions.
1 change: 1 addition & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing
from darts.models.forecasting.fft import FFT
from darts.models.forecasting.kalman_forecaster import KalmanForecaster
from darts.models.forecasting.tbats import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.varima import VARIMA

Expand Down
243 changes: 243 additions & 0 deletions darts/models/forecasting/tbats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
"""
BATS and TBATS
--------------
(T)BATS models [1]_ stand for
* (Trigonometric)
* Box-Cox
* ARMA errors
* Trend
* Seasonal components
They are appropriate to model "complex
seasonal time series such as those with multiple
seasonal periods, high frequency seasonality,
non-integer seasonality and dual-calendar effects" [1]_.
References
----------
.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import numpy as np
from scipy.special import inv_boxcox
from tbats import BATS as tbats_BATS
from tbats import TBATS as tbats_TBATS

from darts.logging import get_logger
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.timeseries import TimeSeries

logger = get_logger(__name__)


def _seasonality_from_freq(series: "TimeSeries"):
    """
    Infer a naive seasonality based on the frequency of the series.

    Parameters
    ----------
    series
        The series whose ``has_range_index`` and ``freq_str`` attributes are inspected.

    Returns
    -------
    Optional[List[int]]
        A single-element list holding the naive seasonal period (e.g. ``[12]`` for
        monthly data), or ``None`` when the series has a range index or when its
        frequency string is not recognised.
    """

    # Without a time index there is no calendar frequency to reason about.
    if series.has_range_index:
        return None

    freq = series.freq_str

    if freq in ("B", "C"):
        return [5]  # business days in a week
    if freq == "D":
        return [7]  # days in a week
    # Weekly frequencies are usually reported with their anchor day ("W-SUN", "W-MON", ...),
    # so an exact match on "W" alone would miss them.
    if freq == "W" or freq.startswith("W-"):
        return [52]  # weeks in a year
    # The prefix test also covers the exact aliases ("M", "BM", "CBM", "SM")
    # and their variants ("MS", "BMS", ...).
    if freq.startswith(("M", "BM", "BS", "CBM", "SM")):
        return [12]  # month
    if freq.startswith(("Q", "BQ", "REQ")):
        return [4]  # quarter
    # Upper-case aliases plus the lower-case spellings used by recent pandas versions.
    if freq in ("H", "BH", "CBH", "h", "bh", "cbh"):
        return [24]  # hour
    if freq in ("T", "min"):
        return [60]  # minute
    if freq in ("S", "s"):
        return [60]  # second

    return None


def _compute_samples(model, predictions, n_samples):
    """
    Draw forecast samples from a fitted tbats model.

    Adapted from Model._calculate_confidence_intervals() in tbats. The sampling is
    re-implemented here because the samples have to be drawn *before* the inverse
    Box-Cox transform is applied.

    Parameters
    ----------
    model
        Fitted model exposing ``matrix``, ``resid_boxcox``, ``y``, ``_boxcox`` and ``params``.
    predictions
        Point forecasts, one value per forecast step.
    n_samples
        Number of samples to draw per step; ``1`` returns the point forecasts themselves.

    Returns
    -------
    np.ndarray
        ``predictions`` with an extra axis at position 1 when ``n_samples == 1``,
        otherwise an array of shape ``(len(predictions), 1, n_samples)``.
    """

    # Deterministic case: hand back the analytic mean, no sampling involved.
    if n_samples == 1:
        return np.expand_dims(predictions, axis=1)

    horizon = len(predictions)

    transition = model.matrix.make_F_matrix()
    g_vec = model.matrix.make_g_vector()
    w_vec = model.matrix.make_w_vector()

    # Per-step coefficients: the first one stays at 1, the following ones are
    # w @ F^(step - 1) @ g, with the power of F accumulated step by step.
    coeffs = np.ones(horizon)
    f_power = np.eye(transition.shape[1])
    for step in range(1, horizon):
        coeffs[step] = w_vec @ f_power @ g_vec
        f_power = f_power @ transition

    # Residual variance in the Box-Cox domain, scaled by the cumulated squared coefficients.
    residuals = model.resid_boxcox
    residual_variance = np.sum(residuals * residuals) / len(model.y)
    step_std = np.sqrt(residual_variance * np.cumsum(np.square(coeffs)))

    # Sample around the Box-Cox-transformed forecasts, then reshape (n_samples, horizon)
    # into (horizon, 1, n_samples).
    draws = np.random.normal(
        loc=model._boxcox(predictions),
        scale=step_std,
        size=(n_samples, horizon),
    )
    draws = np.expand_dims(draws.T, axis=1)

    # Map the samples back to the original scale when a Box-Cox transform was used.
    lam = model.params.box_cox_lambda
    if lam is None:
        return draws
    return inv_boxcox(draws, lam)


class _BaseBatsTbatsModel(ForecastingModel, ABC):
    """
    Common base of the BATS and TBATS wrappers.

    It stores the estimator parameters, optionally infers a naive seasonality from the
    series frequency at fit time, and turns the fitted estimator's forecasts into samples.
    Subclasses only have to implement ``_create_model``.
    """

    def __init__(
        self,
        use_box_cox: Optional[bool] = None,
        box_cox_bounds: Tuple = (0, 1),
        use_trend: Optional[bool] = None,
        use_damped_trend: Optional[bool] = None,
        seasonal_periods: Optional[Union[str, List]] = "freq",
        use_arma_errors: Optional[bool] = True,
        show_warnings: bool = False,
        n_jobs: Optional[int] = None,
        multiprocessing_start_method: Optional[str] = "spawn",
        random_state: int = 0,
    ):
        """
        This is a wrapper around
        `tbats
        <https://github.com/intive-DataScience/tbats>`_.

        This implementation also provides naive frequency inference (when "freq"
        is provided for ``seasonal_periods``),
        as well as Darts-compatible sampling of the resulting normal distribution.

        For convenience, the tbats documentation of the parameters is reported here.

        Parameters
        ----------
        use_box_cox
            If Box-Cox transformation of original series should be applied.
            When ``None`` both cases shall be considered and better is selected by AIC.
        box_cox_bounds
            Minimal and maximal Box-Cox parameter values.
        use_trend
            Indicates whether to include a trend or not.
            When ``None``, both cases shall be considered and the better one is selected by AIC.
        use_damped_trend
            Indicates whether to include a damping parameter in the trend or not.
            Applies only when trend is used.
            When ``None``, both cases shall be considered and the better one is selected by AIC.
        seasonal_periods
            Length of each of the periods (amount of observations in each period).
            TBATS accepts int and float values here.
            BATS accepts only int values.
            When ``None`` or empty array, non-seasonal model shall be fitted.
            If set to ``"freq"``, a single "naive" seasonality
            based on the series frequency will be used (e.g. [12] for monthly series).
            In this latter case, the seasonality will be recomputed every time the model is fit.
        use_arma_errors
            When True BATS will try to improve the model by modelling residuals with ARMA.
            Best model will be selected by AIC.
            If ``False``, ARMA residuals modeling will not be considered.
        show_warnings
            If warnings should be shown or not.
        n_jobs
            How many jobs to run in parallel when fitting BATS model.
            When not provided BATS shall try to utilize all available cpu cores.
        multiprocessing_start_method
            How threads should be started.
            See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
        random_state
            Sets the underlying random seed at model initialization time.
        """
        super().__init__()

        # Keyword arguments forwarded verbatim to the tbats estimator by ``_create_model``.
        self.kwargs = {
            "use_box_cox": use_box_cox,
            "box_cox_bounds": box_cox_bounds,
            "use_trend": use_trend,
            "use_damped_trend": use_damped_trend,
            "seasonal_periods": seasonal_periods,
            "use_arma_errors": use_arma_errors,
            "show_warnings": show_warnings,
            "n_jobs": n_jobs,
            "multiprocessing_start_method": multiprocessing_start_method,
        }

        self.seasonal_periods = seasonal_periods
        # Only compare against "freq" when a string was given: comparing an array-like of
        # periods with a string would be evaluated element-wise rather than as a single flag.
        self.infer_seasonal_periods = (
            isinstance(seasonal_periods, str) and seasonal_periods == "freq"
        )
        self.model = None
        # NOTE: this seeds NumPy's *global* random generator, which is the one
        # ``_compute_samples`` draws from.
        np.random.seed(random_state)

    def __str__(self):
        return "(T)BATS"

    @abstractmethod
    def _create_model(self):
        """Return a new, unfitted tbats estimator built from ``self.kwargs``."""
        pass

    def fit(self, series: TimeSeries):
        """
        Fit the underlying tbats estimator on ``series`` and return ``self``.

        When the model was created with ``seasonal_periods="freq"``, the seasonality is
        re-inferred from the frequency of the training series on every call.
        """
        super().fit(series)
        # presumably set by the parent ``fit`` above -- TODO confirm
        series = self.training_series

        if self.infer_seasonal_periods:
            seasonality = _seasonality_from_freq(series)
            self.kwargs["seasonal_periods"] = seasonality
            self.seasonal_periods = seasonality

        model = self._create_model()
        fitted_model = model.fit(series.values())
        self.model = fitted_model

        return self

    def predict(self, n, num_samples=1):
        """
        Forecast ``n`` steps after the end of the training series.

        With ``num_samples == 1`` the point forecasts of the fitted estimator are used;
        otherwise ``num_samples`` samples are drawn per step by ``_compute_samples``.
        """
        super().predict(n, num_samples)

        yhat = self.model.forecast(steps=n)
        samples = _compute_samples(self.model, yhat, num_samples)

        return self._build_forecast_series(samples)

    def _is_probabilistic(self) -> bool:
        return True

    @property
    def min_train_series_length(self) -> int:
        """
        Minimum number of training points: twice the longest seasonal period when one
        larger than 1 is known, and 3 otherwise.
        """
        periods = self.seasonal_periods

        # Nothing known yet: no seasonality, or "freq" that has not been resolved by ``fit``.
        if periods is None or isinstance(periods, str):
            return 3

        # ``seasonal_periods`` is normally a list (possibly of floats for TBATS);
        # a bare number is accepted as well.
        period_values = np.atleast_1d(periods)
        if period_values.size > 0:
            longest = period_values.max()
            if longest > 1:
                # Round up so that non-integer periods still yield an integer length.
                return int(np.ceil(2 * longest))
        return 3


class TBATS(_BaseBatsTbatsModel):
    """TBATS forecasting model, backed by the ``TBATS`` estimator of the tbats package."""

    def _create_model(self):
        """Build an unfitted tbats ``TBATS`` estimator from the stored keyword arguments."""
        estimator_params = self.kwargs
        return tbats_TBATS(**estimator_params)


class BATS(_BaseBatsTbatsModel):
    """BATS forecasting model, backed by the ``BATS`` estimator of the tbats package."""

    def _create_model(self):
        """Build an unfitted tbats ``BATS`` estimator from the stored keyword arguments."""
        estimator_params = self.kwargs
        return tbats_BATS(**estimator_params)
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@
logger.warning("Prophet not installed - will be skipping Prophet tests")

try:
from darts.models import AutoARIMA
from darts.models import BATS, TBATS, AutoARIMA

models.append((AutoARIMA(), 12.2))
models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0))
models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0))
dual_models.append(AutoARIMA())
PMDARIMA_AVAILABLE = True
except ImportError:
Expand Down
44 changes: 39 additions & 5 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@

logger = get_logger(__name__)

try:
from darts.models import BATS, TBATS

PMDARIMA_AVAILABLE = True
except ImportError:
logger.warning("pmdarima not available. BATS/TBATS probabilistic tests skipped.")
PMDARIMA_AVAILABLE = False

try:
import torch

Expand Down Expand Up @@ -46,10 +54,36 @@
TORCH_AVAILABLE = False

models_cls_kwargs_errs = [
(ExponentialSmoothing, {}, 0.4),
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.17),
(ExponentialSmoothing, {}, 0.3),
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03),
]

if PMDARIMA_AVAILABLE:
models_cls_kwargs_errs += [
(
BATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
(
TBATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
]

if TORCH_AVAILABLE:
models_cls_kwargs_errs += [
(
Expand Down Expand Up @@ -125,11 +159,11 @@ def test_fit_predict_determinism(self):

# whether the first predictions of two models initiated with the same random state are the same
model = model_cls(**model_kwargs)
model.fit(self.constant_ts)
model.fit(self.constant_noisy_ts)
pred1 = model.predict(n=10, num_samples=2).values()

model = model_cls(**model_kwargs)
model.fit(self.constant_ts)
model.fit(self.constant_noisy_ts)
pred2 = model.predict(n=10, num_samples=2).values()

self.assertTrue((pred1 == pred2).all())
Expand Down Expand Up @@ -210,7 +244,7 @@ def helper_test_probabilistic_forecast_accuracy(
(ExponentialLikelihood(), real_pos_series, 0.3, 2),
(DirichletLikelihood(), simplex_series, 0.3, 0.3),
(GeometricLikelihood(), discrete_pos_series, 1, 1),
(CauchyLikelihood(), real_series, 3, 10),
(CauchyLikelihood(), real_series, 3, 11),
(ContinuousBernoulliLikelihood(), bounded_series, 0.1, 0.1),
(HalfNormalLikelihood(), real_pos_series, 0.3, 8),
(LogNormalLikelihood(), real_pos_series, 0.3, 1),
Expand Down
1 change: 1 addition & 0 deletions requirements/pmdarima.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pmdarima>=1.8.0
tbats>=1.1.0

0 comments on commit 16bece0

Please sign in to comment.