Skip to content

Change the way SARIMAXModel works with regressors #380

Merged
merged 2 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions etna/models/sarimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,24 @@ def __init__(
self.kwargs = kwargs
self._model: Optional[SARIMAX] = None
self._result: Optional[SARIMAX] = None
self.regressor_columns: Optional[List[str]] = None

def fit(self, df: pd.DataFrame) -> "_SARIMAXModel":
def fit(self, df: pd.DataFrame, regressors: Optional[List[str]]) -> "_SARIMAXModel":
"""
Fits a SARIMAX model.

Parameters
----------
df:
Features dataframe

regressors:
List of the columns with regressors
Returns
-------
self: SARIMAX
fitted model
"""
self.regressor_columns = regressors
categorical_cols = df.select_dtypes(include=["category"]).columns.tolist()
try:
df.loc[:, categorical_cols] = df[categorical_cols].astype(int)
Expand Down Expand Up @@ -288,27 +291,23 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
return y_pred.reset_index(drop=True, inplace=False)

def _check_df(self, df: pd.DataFrame, horizon: Optional[int] = None):
column_to_drop = [
col for col in df.columns if not col.startswith("regressor") and col not in ["target", "timestamp"]
]
regressor_columns = [col for col in df.columns if col.startswith("regressor")]
column_to_drop = [col for col in df.columns if col not in ["target", "timestamp"] + self.regressor_columns]
if column_to_drop:
warnings.warn(
message=f"SARIMAX model does not work with exogenous features (features unknown in future).\n "
f"{column_to_drop} will be dropped"
)
if horizon:
short_regressors = [regressor for regressor in regressor_columns if df[regressor].count() < horizon]
short_regressors = [regressor for regressor in self.regressor_columns if df[regressor].count() < horizon]
if short_regressors:
raise ValueError(
f"Regressors {short_regressors} are too short for chosen horizon value.\n "
"Try lower horizon value, or drop this regressors."
)

def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
regressor_columns = [col for col in df.columns if col.startswith("regressor")]
if regressor_columns:
exog_future = df[regressor_columns]
if self.regressor_columns:
exog_future = df[self.regressor_columns]
exog_future.index = df["timestamp"]
else:
exog_future = None
Expand Down Expand Up @@ -479,6 +478,21 @@ def __init__(
)
)

@log_decorator
def fit(self, ts: TSDataset) -> "SARIMAXModel":
"""Fit model."""
self._segments = ts.segments
self._build_models()

for segment in self._segments:
model = self._models[segment]
segment_features = ts[:, segment, :]
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features, regressors=ts.regressors)
return self

@staticmethod
def _forecast_one_segment(
model,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_models/test_sarimax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ def test_sarimax_forecaster_run(example_tsds):
assert len(res) == 14


def test_sarimax_save_regressors_on_fit(example_reg_tsds):
model = SARIMAXModel()
model.fit(ts=example_reg_tsds)
for segment_model in model._models.values():
assert sorted(segment_model.regressor_columns) == example_reg_tsds.regressors


def test_sarimax_select_regressors_correctly(example_reg_tsds):
model = SARIMAXModel()
model.fit(ts=example_reg_tsds)
for segment, segment_model in model._models.items():
segment_features = example_reg_tsds[:, segment, :].droplevel("segment", axis=1)
segment_regressors_expected = segment_features[example_reg_tsds.regressors]
segment_regressors = segment_model._select_regressors(df=segment_features.reset_index())
assert (segment_regressors == segment_regressors_expected).all().all()


def test_sarimax_forecaster_run_with_reg(example_reg_tsds):
"""
Given: I have dataframe with 2 segments
Expand Down