Skip to content

Commit

Permalink
Change the way SARIMAXModel works with regressors (#380)
Browse files Browse the repository at this point in the history
* Update _SARIMAX methods

* Update tests
  • Loading branch information
alex-hse-repository committed Feb 14, 2022
1 parent e4bfa04 commit 419515a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
34 changes: 24 additions & 10 deletions etna/models/sarimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,24 @@ def __init__(
self.kwargs = kwargs
self._model: Optional[SARIMAX] = None
self._result: Optional[SARIMAX] = None
self.regressor_columns: Optional[List[str]] = None

def fit(self, df: pd.DataFrame) -> "_SARIMAXModel":
def fit(self, df: pd.DataFrame, regressors: Optional[List[str]]) -> "_SARIMAXModel":
"""
Fits a SARIMAX model.
Parameters
----------
df:
Features dataframe
regressors:
List of the columns with regressors
Returns
-------
self: SARIMAX
fitted model
"""
self.regressor_columns = regressors
categorical_cols = df.select_dtypes(include=["category"]).columns.tolist()
try:
df.loc[:, categorical_cols] = df[categorical_cols].astype(int)
Expand Down Expand Up @@ -278,27 +281,23 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
return y_pred.reset_index(drop=True, inplace=False)

def _check_df(self, df: pd.DataFrame, horizon: Optional[int] = None):
column_to_drop = [
col for col in df.columns if not col.startswith("regressor") and col not in ["target", "timestamp"]
]
regressor_columns = [col for col in df.columns if col.startswith("regressor")]
column_to_drop = [col for col in df.columns if col not in ["target", "timestamp"] + self.regressor_columns]
if column_to_drop:
warnings.warn(
message=f"SARIMAX model does not work with exogenous features (features unknown in future).\n "
f"{column_to_drop} will be dropped"
)
if horizon:
short_regressors = [regressor for regressor in regressor_columns if df[regressor].count() < horizon]
short_regressors = [regressor for regressor in self.regressor_columns if df[regressor].count() < horizon]
if short_regressors:
raise ValueError(
f"Regressors {short_regressors} are too short for chosen horizon value.\n "
"Try lower horizon value, or drop this regressors."
)

def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
regressor_columns = [col for col in df.columns if col.startswith("regressor")]
if regressor_columns:
exog_future = df[regressor_columns]
if self.regressor_columns:
exog_future = df[self.regressor_columns]
exog_future.index = df["timestamp"]
else:
exog_future = None
Expand Down Expand Up @@ -469,6 +468,21 @@ def __init__(
)
)

@log_decorator
def fit(self, ts: TSDataset) -> "SARIMAXModel":
"""Fit model."""
self._segments = ts.segments
self._build_models()

for segment in self._segments:
model = self._models[segment]
segment_features = ts[:, segment, :]
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features, regressors=ts.regressors)
return self

@staticmethod
def _forecast_one_segment(
model,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_models/test_sarimax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ def test_sarimax_forecaster_run(example_tsds):
assert len(res) == 14


def test_sarimax_save_regressors_on_fit(example_reg_tsds):
model = SARIMAXModel()
model.fit(ts=example_reg_tsds)
for segment_model in model._models.values():
assert sorted(segment_model.regressor_columns) == example_reg_tsds.regressors


def test_sarimax_select_regressors_correctly(example_reg_tsds):
model = SARIMAXModel()
model.fit(ts=example_reg_tsds)
for segment, segment_model in model._models.items():
segment_features = example_reg_tsds[:, segment, :].droplevel("segment", axis=1)
segment_regressors_expected = segment_features[example_reg_tsds.regressors]
segment_regressors = segment_model._select_regressors(df=segment_features.reset_index())
assert (segment_regressors == segment_regressors_expected).all().all()


def test_sarimax_forecaster_run_with_reg(example_reg_tsds):
"""
Given: I have dataframe with 2 segments
Expand Down

0 comments on commit 419515a

Please sign in to comment.