Skip to content

Commit

Permalink
Merge branch 'master' into feature/rnn-normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFidor committed May 17, 2023
2 parents 544ab44 + a7e2ec1 commit 1ded1dc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
18 changes: 13 additions & 5 deletions darts/models/forecasting/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Optional, Tuple

import numpy as np
from statsmodels import __version_tuple__ as statsmodels_version
from statsmodels.tsa.arima.model import ARIMA as staARIMA

from darts.logging import get_logger
Expand All @@ -23,6 +24,9 @@

logger = get_logger(__name__)

# Check whether we are running statsmodels >= 0.13.5 or not:
statsmodels_above_0135 = statsmodels_version > (0, 13, 5)


class ARIMA(TransferableFutureCovariatesLocalForecastingModel):
def __init__(
Expand Down Expand Up @@ -81,11 +85,15 @@ def __init__(
self.seasonal_order = seasonal_order
self.trend = trend
self.model = None
self._random_state = (
random_state
if random_state is None
else np.random.RandomState(random_state)
)
if statsmodels_above_0135:
self._random_state = (
random_state
if random_state is None
else np.random.RandomState(random_state)
)
else:
self._random_state = None
np.random.seed(random_state if random_state is not None else 0)

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
super()._fit(series, future_covariates)
Expand Down
5 changes: 3 additions & 2 deletions requirements/core.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
catboost>=1.0.6
holidays>=0.11.1
catboost>=1.0.6,<1.2.0
holidays>=0.11.1,<0.25.0
joblib>=0.16.0
lightgbm>=3.2.0
matplotlib>=3.3.0
Expand All @@ -17,5 +17,6 @@ statsforecast>=1.4
statsmodels>=0.13.0
tbats>=1.1.0
tqdm>=4.60.0
typing-extensions
xarray>=0.17.0
xgboost>=1.6.0

0 comments on commit 1ded1dc

Please sign in to comment.