Skip to content

Commit

Permalink
Inherit all PerSegment models from PerSegmentModel (#543)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Feb 18, 2022
1 parent ab902ba commit 2242d40
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 37 deletions.
14 changes: 9 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521))
-
- Add plot_residuals ([#539](https://github.com/tinkoff-ai/etna/pull/539))

-
- Create `PerSegmentBaseModel`, `PerSegmentPredictionIntervalModel` ([#537](https://github.com/tinkoff-ai/etna/pull/537))
-
### Changed
- Change the way `ProphetModel` works with regressors ([#383](https://github.com/tinkoff-ai/etna/pull/383))
- Change the way `SARIMAXModel` works with regressors ([#380](https://github.com/tinkoff-ai/etna/pull/380))
- Change the way `Sklearn` models works with regressors ([#440](https://github.com/tinkoff-ai/etna/pull/440))
- Change the way `FeatureSelectionTransform` works with regressors, rename variables replacing the "regressor" to "feature" ([#522](https://github.com/tinkoff-ai/etna/pull/522))
-
-
-
-
- Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526))
-
-
- Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540))
- Update CONTRIBUTING.md ([#536](https://github.com/tinkoff-ai/etna/pull/536))

-
- Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543))
-
### Fixed
- Fix `TSDataset._update_regressors` logic removing the regressors ([#489](https://github.com/tinkoff-ai/etna/pull/489))
- Fix `TSDataset.info`, `TSDataset.describe` methods ([#519](https://github.com/tinkoff-ai/etna/pull/519))
Expand Down
2 changes: 1 addition & 1 deletion etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def fit(self, ts: TSDataset) -> "PerSegmentBaseModel":
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features)
model.fit(df=segment_features, regressors=ts.regressors)
return self

def get_model(self) -> Dict[str, Any]:
Expand Down
11 changes: 6 additions & 5 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from typing import Optional

import numpy as np
Expand All @@ -11,7 +12,7 @@
from etna.models.base import log_decorator


class _CatBoostModel:
class _CatBoostAdapter:
def __init__(
self,
iterations: Optional[int] = None,
Expand All @@ -34,7 +35,7 @@ def __init__(
)
self._categorical = None

def fit(self, df: pd.DataFrame) -> "_CatBoostModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
features = df.drop(columns=["timestamp", "target"])
target = df["target"]
self._categorical = features.select_dtypes(include=["category"]).columns.to_list()
Expand Down Expand Up @@ -150,7 +151,7 @@ def __init__(
self.thread_count = thread_count
self.kwargs = kwargs
super(CatBoostModelPerSegment, self).__init__(
base_model=_CatBoostModel(
base_model=_CatBoostAdapter(
iterations=iterations,
depth=depth,
learning_rate=learning_rate,
Expand Down Expand Up @@ -263,7 +264,7 @@ def __init__(
self.thread_count = thread_count
self.kwargs = kwargs
super(CatBoostModelMultiSegment, self).__init__()
self._base_model = _CatBoostModel(
self._base_model = _CatBoostAdapter(
iterations=iterations,
depth=depth,
learning_rate=learning_rate,
Expand All @@ -279,7 +280,7 @@ def fit(self, ts: TSDataset) -> "CatBoostModelMultiSegment":
df = ts.to_pandas(flatten=True)
df = df.dropna()
df = df.drop(columns="segment")
self._base_model.fit(df=df)
self._base_model.fit(df=df, regressors=ts.regressors)
return self

@log_decorator
Expand Down
9 changes: 5 additions & 4 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
Expand All @@ -13,7 +14,7 @@
from etna.models.base import PerSegmentModel


class _HoltWintersModel:
class _HoltWintersAdapter:
"""
Class for holding Holt-Winters' exponential smoothing model.
Expand Down Expand Up @@ -168,7 +169,7 @@ def __init__(
self._model: Optional[ExponentialSmoothing] = None
self._result: Optional[HoltWintersResults] = None

def fit(self, df: pd.DataFrame) -> "_HoltWintersModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_HoltWintersAdapter":
"""
Fits a Holt-Winters' model.
Expand All @@ -179,7 +180,7 @@ def fit(self, df: pd.DataFrame) -> "_HoltWintersModel":
Returns
-------
self: _HoltWintersModel
self: _HoltWintersAdapter
fitted model
"""
self._check_df(df)
Expand Down Expand Up @@ -396,7 +397,7 @@ def __init__(
self.damping_trend = damping_trend
self.fit_kwargs = fit_kwargs
super().__init__(
base_model=_HoltWintersModel(
base_model=_HoltWintersAdapter(
trend=self.trend,
damped_trend=self.damped_trend,
seasonal=self.seasonal,
Expand Down
3 changes: 2 additions & 1 deletion etna/models/seasonal_ma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -33,7 +34,7 @@ def __init__(self, window: int = 5, seasonality: int = 7):
self.seasonality = seasonality
self.shift = self.window * self.seasonality

def fit(self, df: pd.DataFrame) -> "_SeasonalMovingAverageModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SeasonalMovingAverageModel":
"""
Fitting simple model on given series.
Expand Down
23 changes: 4 additions & 19 deletions etna/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from etna.models.base import log_decorator


class _SklearnModel:
class _SklearnAdapter:
def __init__(self, regressor: RegressorMixin):
self.model = regressor
self.regressor_columns: Optional[List[str]] = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SklearnModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SklearnAdapter":
self.regressor_columns = regressors
try:
features = df[self.regressor_columns].apply(pd.to_numeric)
Expand Down Expand Up @@ -47,22 +47,7 @@ def __init__(self, regressor: RegressorMixin):
regressor:
sklearn model for regression
"""
super().__init__(base_model=_SklearnModel(regressor=regressor))

@log_decorator
def fit(self, ts: TSDataset) -> "SklearnPerSegmentModel":
"""Fit model."""
self._segments = ts.segments
self._build_models()

for segment in self._segments:
model = self._models[segment] # type: ignore
segment_features = ts[:, segment, :]
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features, regressors=ts.regressors)
return self
super().__init__(base_model=_SklearnAdapter(regressor=regressor))


class SklearnMultiSegmentModel(Model):
Expand All @@ -78,7 +63,7 @@ def __init__(self, regressor: RegressorMixin):
sklearn model for regression
"""
super().__init__()
self._base_model = _SklearnModel(regressor=regressor)
self._base_model = _SklearnAdapter(regressor=regressor)

@log_decorator
def fit(self, ts: TSDataset) -> "SklearnMultiSegmentModel":
Expand Down
3 changes: 1 addition & 2 deletions tests/test_models/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def linear_segments_ts_common(random_seed):
return linear_segments_by_parameters(alpha_values, intercept_values)


@pytest.mark.xfail
@pytest.mark.parametrize("model", (LinearPerSegmentModel(), ElasticPerSegmentModel()))
def test_not_fitted(model, linear_segments_ts_unique):
"""Check exception when trying to forecast with unfitted model."""
Expand All @@ -87,7 +86,7 @@ def test_not_fitted(model, linear_segments_ts_unique):
train.fit_transform([lags])

to_forecast = train.make_future(3)
with pytest.raises(ValueError, match="model is not fitted"):
with pytest.raises(ValueError, match="not fitted model!"):
model.forecast(to_forecast)


Expand Down

0 comments on commit 2242d40

Please sign in to comment.