Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hrzn committed Mar 8, 2022
1 parent ac9c0f9 commit 44a5fbf
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions darts/models/forecasting/tbats.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,16 @@ def __init__(
----------
use_box_cox
If Box-Cox transformation of original series should be applied.
When None both cases shall be considered and better is selected by AIC.
When ``None`` both cases shall be considered and better is selected by AIC.
box_cox_bounds
Minimal and maximal Box-Cox parameter values.
use_trend
Indicates whether to include a trend or not.
When None, both cases shall be considered and the better one is selected by AIC.
When ``None``, both cases shall be considered and the better one is selected by AIC.
use_damped_trend
Indicates whether to include a damping parameter in the trend or not.
Applies only when trend is used.
When None, both cases shall be considered and the better one is selected by AIC.
When ``None``, both cases shall be considered and the better one is selected by AIC.
seasonal_periods
Length of each of the periods (amount of observations in each period).
TBATS accepts int and float values here.
Expand All @@ -162,13 +162,13 @@ def __init__(
use_arma_errors
When True BATS will try to improve the model by modelling residuals with ARMA.
Best model will be selected by AIC.
If False, ARMA residuals modeling will not be considered.
If ``False``, ARMA residuals modeling will not be considered.
show_warnings
If warnings should be shown or not.
n_jobs: int, optional (default=None)
n_jobs
How many jobs to run in parallel when fitting BATS model.
When not provided BATS shall try to utilize all available cpu cores.
multiprocessing_start_method: str, optional (default='spawn')
multiprocessing_start_method
How threads should be started.
See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
random_state
Expand Down Expand Up @@ -199,14 +199,15 @@ def __str__(self):
@abstractmethod
def _create_model(self):
pass
# return tbats_TBATS(**self.kwargs)

def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series

if self.infer_seasonal_periods:
self.kwargs["seasonal_periods"] = _seasonality_from_freq(series)
seasonality = _seasonality_from_freq(series)
self.kwargs["seasonal_periods"] = seasonality
self.seasonal_periods = seasonality

model = self._create_model()
fitted_model = model.fit(series.values())
Expand Down

0 comments on commit 44a5fbf

Please sign in to comment.