Skip to content

Commit

Permalink
Exemplary concrete estimator refactor post interface refactor, of Nai…
Browse files Browse the repository at this point in the history
…veForecaster (#953)

This is an exemplary refactor of a concrete estimator class, `NaiveForecaster`, to explore how general concrete forecaster refactors would work, along the lines discussed in #912.

This PR changes:

* `_BaseWindowForecaster` still inherits from `BaseForecaster` directly, and already looked extension spec compliant (no overrides, no tags)
* `NaiveForecaster` inherits from `_BaseWindowForecaster`, and has been made extension spec compliant by adding the `requires-fh-in-fit` tag, and moving core logic from `fit` to `_fit`, while avoiding to override `fit`
* it adds one line in `BaseForecaster.fit`, as a general change: the `cutoff` is set to the latest `y` index using `_set_cutoff`. This is done pre-`_fit`, which could override the set cut-off.
  • Loading branch information
fkiraly committed Jun 18, 2021
1 parent b64c3cb commit 69b5a64
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
2 changes: 2 additions & 0 deletions sktime/forecasting/base/_base.py
Expand Up @@ -105,6 +105,8 @@ def fit(self, y, X=None, fh=None):
self._X = X
self._y = y

self._set_cutoff(y.index[-1])

self._fit(y=y, X=X, fh=fh)

# this should happen last
Expand Down
11 changes: 4 additions & 7 deletions sktime/forecasting/naive.py
Expand Up @@ -11,12 +11,11 @@

from sktime.forecasting.base._base import DEFAULT_ALPHA
from sktime.forecasting.base._sktime import _BaseWindowForecaster
from sktime.forecasting.base._sktime import _OptionalForecastingHorizonMixin
from sktime.utils.validation.forecasting import check_sp
from sktime.utils.validation import check_window_length


class NaiveForecaster(_OptionalForecastingHorizonMixin, _BaseWindowForecaster):
class NaiveForecaster(_BaseWindowForecaster):
"""
NaiveForecaster is a forecaster that makes forecasts using simple
strategies.
Expand Down Expand Up @@ -59,13 +58,15 @@ class NaiveForecaster(_OptionalForecastingHorizonMixin, _BaseWindowForecaster):
>>> y_pred = forecaster.predict(fh=[1,2,3])
"""

_tags = {"requires-fh-in-fit": False}

def __init__(self, strategy="last", window_length=None, sp=1):
super(NaiveForecaster, self).__init__()
self.strategy = strategy
self.sp = sp
self.window_length = window_length

def fit(self, y, X=None, fh=None):
def _fit(self, y, X=None, fh=None):
"""Fit to training data.
Parameters
Expand All @@ -81,10 +82,7 @@ def fit(self, y, X=None, fh=None):
self : returns an instance of self.
"""
# X_train is ignored
self._is_fitted = False

self._set_y_X(y, X)
self._set_fh(fh)
n_timepoints = y.shape[0]

if self.strategy == "last":
Expand Down Expand Up @@ -152,7 +150,6 @@ def fit(self, y, X=None, fh=None):
f"the training series."
)

self._is_fitted = True
return self

def _predict_last_window(
Expand Down

0 comments on commit 69b5a64

Please sign in to comment.