Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/tbats #816

Merged
merged 22 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing
from darts.models.forecasting.fft import FFT
from darts.models.forecasting.kalman_forecaster import KalmanForecaster
from darts.models.forecasting.tbats import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.varima import VARIMA

Expand Down
243 changes: 243 additions & 0 deletions darts/models/forecasting/tbats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
"""
BATS and TBATS
--------------

(T)BATS models [1]_ stand for

* (Trigonometric)
* Box-Cox
* ARMA errors
* Trend
* Seasonal components

They are appropriate to model "complex
seasonal time series such as those with multiple
seasonal periods, high frequency seasonality,
non-integer seasonality and dual-calendar effects" [1]_.

References
----------
.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import numpy as np
from scipy.special import inv_boxcox
from tbats import BATS as tbats_BATS
from tbats import TBATS as tbats_TBATS

from darts.logging import get_logger
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.timeseries import TimeSeries

logger = get_logger(__name__)


def _seasonality_from_freq(series: TimeSeries):
hrzn marked this conversation as resolved.
Show resolved Hide resolved
"""
Infer a naive seasonality based on the frequency
"""

if series.has_range_index:
return None

freq = series.freq_str

if freq in ["B", "C"]:
return [5]
elif freq == "D":
return [7]
elif freq == "W":
return [52]
elif freq in ["M", "BM", "CBM", "SM"] or freq.startswith(
("M", "BM", "BS", "CBM", "SM")
):
return [12] # month
elif freq in ["Q", "BQ", "REQ"] or freq.startswith(("Q", "BQ", "REQ")):
return [4] # quarter
elif freq in ["H", "BH", "CBH"]:
return [24] # hour
elif freq in ["T", "min"]:
return [60] # minute
elif freq == "S":
return [60] # second

return None


def _compute_samples(model, predictions, n_samples):
"""
This function is drawn from Model._calculate_confidence_intervals() in tbats.
We have to implement our own version here in order to compute the samples before
the inverse boxcox transform.
"""

# In the deterministic case we return the analytic mean
if n_samples == 1:
return np.expand_dims(predictions, axis=1)

F = model.matrix.make_F_matrix()
g = model.matrix.make_g_vector()
w = model.matrix.make_w_vector()

c = np.asarray([1.0] * len(predictions))
f_running = np.identity(F.shape[1])
for step in range(1, len(predictions)):
c[step] = w @ f_running @ g
f_running = f_running @ F
variance_multiplier = np.cumsum(c * c)

base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len(
model.y
)
variance_boxcox = base_variance_boxcox * variance_multiplier
std_boxcox = np.sqrt(variance_boxcox)

# get the samples before inverse boxcoxing
samples = np.random.normal(
loc=model._boxcox(predictions),
scale=std_boxcox,
size=(n_samples, len(predictions)),
).T
samples = np.expand_dims(samples, axis=1)

# apply inverse boxcox if needed
boxcox_lambda = model.params.box_cox_lambda
if boxcox_lambda is not None:
samples = inv_boxcox(samples, boxcox_lambda)

return samples


class _BaseBatsTbatsModel(ForecastingModel, ABC):
def __init__(
self,
use_box_cox: Optional[bool] = None,
box_cox_bounds: Tuple = (0, 1),
use_trend: Optional[bool] = None,
use_damped_trend: Optional[bool] = None,
seasonal_periods: Optional[Union[str, List]] = "freq",
use_arma_errors: Optional[bool] = True,
show_warnings: bool = False,
n_jobs: Optional[int] = None,
multiprocessing_start_method: Optional[str] = "spawn",
random_state: int = 0,
):

"""
This is a wrapper around
`tbats
<https://github.com/intive-DataScience/tbats>`_.

This implementation also provides naive frequency inference (when "freq"
is provided for ``seasonal_periods``),
as well as Darts-compatible sampling of the resulting normal distribution.

For convenience, the tbats documentation of the parameters is reported here.

Parameters
----------
use_box_cox
If Box-Cox transformation of original series should be applied.
When ``None`` both cases shall be considered and better is selected by AIC.
box_cox_bounds
Minimal and maximal Box-Cox parameter values.
use_trend
Indicates whether to include a trend or not.
When ``None``, both cases shall be considered and the better one is selected by AIC.
use_damped_trend
Indicates whether to include a damping parameter in the trend or not.
Applies only when trend is used.
When ``None``, both cases shall be considered and the better one is selected by AIC.
seasonal_periods
Length of each of the periods (amount of observations in each period).
TBATS accepts int and float values here.
BATS accepts only int values.
When ``None`` or empty array, non-seasonal model shall be fitted.
hrzn marked this conversation as resolved.
Show resolved Hide resolved
If set to ``"freq"``, a single "naive" seasonality
based on the series frequency will be used (e.g. [12] for monthly series).
In this latter case, the seasonality will be recomputed every time the model is fit.
use_arma_errors
When True BATS will try to improve the model by modelling residuals with ARMA.
Best model will be selected by AIC.
If ``False``, ARMA residuals modeling will not be considered.
show_warnings
If warnings should be shown or not.
n_jobs
How many jobs to run in parallel when fitting BATS model.
When not provided BATS shall try to utilize all available cpu cores.
multiprocessing_start_method
How threads should be started.
See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
random_state
Sets the underlying random seed at model initialization time.
"""
super().__init__()

self.kwargs = {
"use_box_cox": use_box_cox,
"box_cox_bounds": box_cox_bounds,
"use_trend": use_trend,
"use_damped_trend": use_damped_trend,
"seasonal_periods": seasonal_periods,
"use_arma_errors": use_arma_errors,
"show_warnings": show_warnings,
"n_jobs": n_jobs,
"multiprocessing_start_method": multiprocessing_start_method,
}

self.seasonal_periods = seasonal_periods
self.infer_seasonal_periods = seasonal_periods == "freq"
self.model = None
np.random.seed(random_state)

def __str__(self):
return "(T)BATS"

@abstractmethod
def _create_model(self):
pass

def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series

if self.infer_seasonal_periods:
seasonality = _seasonality_from_freq(series)
self.kwargs["seasonal_periods"] = seasonality
self.seasonal_periods = seasonality

model = self._create_model()
fitted_model = model.fit(series.values())
self.model = fitted_model

return self

def predict(self, n, num_samples=1):
super().predict(n, num_samples)

yhat = self.model.forecast(steps=n)
samples = _compute_samples(self.model, yhat, num_samples)

return self._build_forecast_series(samples)

def _is_probabilistic(self) -> bool:
return True

@property
def min_train_series_length(self) -> int:
if isinstance(self.seasonal_periods, int) and self.seasonal_periods > 1:
hrzn marked this conversation as resolved.
Show resolved Hide resolved
return 2 * self.seasonal_periods
return 3


class TBATS(_BaseBatsTbatsModel):
def _create_model(self):
return tbats_TBATS(**self.kwargs)


class BATS(_BaseBatsTbatsModel):
def _create_model(self):
return tbats_BATS(**self.kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@
logger.warning("Prophet not installed - will be skipping Prophet tests")

try:
from darts.models import AutoARIMA
from darts.models import BATS, TBATS, AutoARIMA

models.append((AutoARIMA(), 12.2))
models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0))
models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0))
dual_models.append(AutoARIMA())
PMDARIMA_AVAILABLE = True
except ImportError:
Expand Down
44 changes: 39 additions & 5 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@

logger = get_logger(__name__)

try:
from darts.models import BATS, TBATS

PMDARIMA_AVAILABLE = True
except ImportError:
logger.warning("pmdarima not available. BATS/TBATS probabilistic tests skipped.")
PMDARIMA_AVAILABLE = False

try:
import torch

Expand Down Expand Up @@ -46,10 +54,36 @@
TORCH_AVAILABLE = False

models_cls_kwargs_errs = [
(ExponentialSmoothing, {}, 0.4),
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.17),
(ExponentialSmoothing, {}, 0.3),
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03),
]

if PMDARIMA_AVAILABLE:
models_cls_kwargs_errs += [
(
BATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
(
TBATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
]

if TORCH_AVAILABLE:
models_cls_kwargs_errs += [
(
Expand Down Expand Up @@ -125,11 +159,11 @@ def test_fit_predict_determinism(self):

# whether the first predictions of two models initiated with the same random state are the same
model = model_cls(**model_kwargs)
model.fit(self.constant_ts)
model.fit(self.constant_noisy_ts)
pred1 = model.predict(n=10, num_samples=2).values()

model = model_cls(**model_kwargs)
model.fit(self.constant_ts)
model.fit(self.constant_noisy_ts)
pred2 = model.predict(n=10, num_samples=2).values()

self.assertTrue((pred1 == pred2).all())
Expand Down Expand Up @@ -210,7 +244,7 @@ def helper_test_probabilistic_forecast_accuracy(
(ExponentialLikelihood(), real_pos_series, 0.3, 2),
(DirichletLikelihood(), simplex_series, 0.3, 0.3),
(GeometricLikelihood(), discrete_pos_series, 1, 1),
(CauchyLikelihood(), real_series, 3, 10),
(CauchyLikelihood(), real_series, 3, 11),
(ContinuousBernoulliLikelihood(), bounded_series, 0.1, 0.1),
(HalfNormalLikelihood(), real_pos_series, 0.3, 8),
(LogNormalLikelihood(), real_pos_series, 0.3, 1),
Expand Down
1 change: 1 addition & 0 deletions requirements/pmdarima.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pmdarima>=1.8.0
tbats>=1.1.0