-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Forecast base class refactor and extension template (#912)
This PR is addressing the forecaster base class proliferation discussed in #510, by simplifying the base class inheritance tree and the streamlining the private method logic. Specifically, I've moved the logic from `_SktimeForecaster` and the forecasting horizin mixins into `BaseForecaster`, and adopted uniformly the principle of having a public method (`fit`, `predict`, etc) with checks/plumbing, dispatching to a "core logic" version (`_fit`, `_predict`, etc) where validated arguments can be assumed. The burden on extenders becomes much lighter, since it is now possible to only focus on the "core logic" implementation, instead of having to keep in mind a myriad of inconsistent and constantly shifting conventions in checks and other plumbing. As side effects, if we get this right, this should make a few things on the roadmap easier: * the extension guidelines #464. Right now, the aforementioned implicit conventions are too many and intricate to write useful extension guidelines, in my opinion. * extending to the multivariate case * input/output checks and the eternal data container discussion, this can go in the plumbing As a proof-of-concept regarding ease of extension, this PR also contains a highly annotated extension template in the `extension_templates` folder. In terms of review, the key file is `forecasting/base/_base`, with corresponding changes (contraction and deletion) in `base/_sktime`. I've tried to keep the interface consistent as much as possible (only changing internal logic). Interface contracts with all the earlier estimators are still honoured, via loopthroughs and default behaviour that ensures that everything still works if `fit`, `predict` etc are overridden by the current descendants, as opposed to `_fit` and `_predict`. The only change I had to made to descendants is set `self._is_fitted=False` at the start of `fit`, which is a minimally invasive change that's also separately reviewable as PR #941.
- Loading branch information
Showing
5 changed files
with
918 additions
and
606 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Extension template for forecasters. | ||
How to use this: | ||
- this is meant as a "fill in" template for easy extension | ||
- do NOT import this file directly - it will break | ||
- work through all the "todo" comments below | ||
- fill in code for mandatory methods, and optionally for optional methods | ||
- you can add more private methods, but do not override BaseEstimator's private methods | ||
an easy way to be safe is to prefix your methods with "_custom" | ||
- change docstrings for functions and the file | ||
- ensure interface compatibility by testing forecasting/tests/test_all_forecasters | ||
and forecasting/tests/test_sktime_forecasters | ||
- once complete: use as a local library, or contribute to sktime via PR | ||
Mandatory implements: | ||
fitting - _fit(self, y, X=None, fh=None) | ||
forecasting - _predict(self, fh=None, X=None, return_pred_int=False, | ||
alpha=DEFAULT_ALPHA) | ||
Optional implements: | ||
updating - _update(self, y, X=None, update_params=True): | ||
fitted parameter inspection - get_fitted_params() | ||
State: | ||
fitted model/strategy - by convention, any attributes ending in "_" | ||
fitted state flag - check_is_fitted() | ||
copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
""" | ||
|
||
from sktime.forecasting.base import BaseForecaster | ||
from sktime.forecasting.base._base import DEFAULT_ALPHA | ||
|
||
# todo: add any necessary imports here | ||
|
||
|
||
class MyForecaster(BaseForecaster): | ||
"""Custom forecaster. todo: write docstring. | ||
todo: describe your custom forecaster here | ||
Hyper-parameters | ||
---------------- | ||
parama : int | ||
descriptive explanation of parama | ||
paramb : string, optional (default='default') | ||
descriptive explanation of paramb | ||
paramc : boolean, optional (default= whether paramb is not the default) | ||
descriptive explanation of paramc | ||
and so on | ||
Components | ||
---------- | ||
est : sktime.estimator, BaseEstimator descendant | ||
descriptive explanation of est | ||
est2: another estimator | ||
descriptive explanation of est2 | ||
and so on | ||
""" | ||
|
||
# todo: fill out estimator tags here | ||
_tags = { | ||
"fh_in_fit": True, # is the forecasting horizon already required in fit? | ||
"handles_missing_data": False, # can the estimator handle missing data? | ||
"univariate_only": True, # can the estimator deal with multivariate series? | ||
} | ||
# in case of inheritance, concrete class should set all tags | ||
# alternatively, descendants can set tags in __init__ (avoid this if possible) | ||
|
||
# todo: add any hyper-parameters and components to constructor | ||
def __init__(self, est, parama, est2=None, paramb="default", paramc=None): | ||
# estimators should precede parameters | ||
# if estimators have default values, set None and initalize below | ||
|
||
# todo: write any hyper-parameters and components to self | ||
self.est = est | ||
self.parama = parama | ||
self.paramb = paramb | ||
self.paramc = paramc | ||
# important: no checking or other logic should happen here | ||
|
||
# todo: default estimators should have None arg defaults | ||
# and be initialized here | ||
# do this only with default estimators, not with parameters | ||
# if est2 is None: | ||
# self.estimator = MyDefaultEstimator() | ||
|
||
# todo: change "MyForecaster" to the name of the class | ||
super(MyForecaster, self).__init__() | ||
|
||
# todo: implement this, mandatory | ||
def _fit(self, y, X=None, fh=None): | ||
"""Fit forecaster to training data. | ||
core logic | ||
Parameters | ||
---------- | ||
y : pd.Series | ||
Target time series to which to fit the forecaster. | ||
fh : int, list, np.array or ForecastingHorizon, optional (default=None) | ||
The forecasters horizon with the steps ahead to to predict. | ||
X : pd.DataFrame, optional (default=None) | ||
Returns | ||
------- | ||
self : returns an instance of self. | ||
""" | ||
|
||
# implement here | ||
# IMPORTANT: avoid side effects to y, X, fh | ||
|
||
# todo: implement this, mandatory | ||
def _predict(self, fh, X=None, return_pred_int=False, alpha=DEFAULT_ALPHA): | ||
"""Forecast time series at future horizon. | ||
core logic | ||
Parameters | ||
---------- | ||
fh : int, list, np.array or ForecastingHorizon | ||
Forecasting horizon | ||
X : pd.DataFrame, optional (default=None) | ||
Exogenous time series | ||
return_pred_int : bool, optional (default=False) | ||
If True, returns prediction intervals for given alpha values. | ||
alpha : float or list, optional (default=0.95) | ||
Returns | ||
------- | ||
y_pred : pd.Series | ||
Point predictions | ||
y_pred_int : pd.DataFrame - only if return_pred_int=True | ||
Prediction intervals | ||
""" | ||
|
||
# implement here | ||
# IMPORTANT: avoid side effects to X, fh | ||
|
||
# todo: consider implementing this, optional | ||
# if not implementing, delete the _update method | ||
def _update(self, y, X=None, update_params=True): | ||
"""Update time series to incremental training data. | ||
core logic | ||
Parameters | ||
---------- | ||
fh : int, list, np.array or ForecastingHorizon | ||
Forecasting horizon | ||
X : pd.DataFrame, optional (default=None) | ||
Exogenous time series | ||
return_pred_int : bool, optional (default=False) | ||
If True, returns prediction intervals for given alpha values. | ||
alpha : float or list, optional (default=0.95) | ||
Returns | ||
------- | ||
y_pred : pd.Series | ||
Point predictions | ||
y_pred_int : pd.DataFrame - only if return_pred_int=True | ||
Prediction intervals | ||
State change | ||
------------ | ||
updates self._X and self._y with new data | ||
updates self.cutoff to most recent time in y | ||
if update_params=True, updates model (attributes ending in "_") | ||
""" | ||
|
||
# implement here | ||
# IMPORTANT: avoid side effects to X, fh | ||
|
||
# todo: consider implementing this, optional | ||
# if not implementing, delete the method | ||
def _update_predict_single( | ||
self, | ||
y, | ||
fh, | ||
X=None, | ||
update_params=True, | ||
return_pred_int=False, | ||
alpha=DEFAULT_ALPHA, | ||
): | ||
"""Update forecaster and then make forecasts. | ||
Implements default behaviour of calling update and predict | ||
sequentially, but can be overwritten by subclasses | ||
to implement more efficient updating algorithms when available. | ||
""" | ||
self.update(y, X, update_params=update_params) | ||
return self.predict(fh, X, return_pred_int=return_pred_int, alpha=alpha) | ||
|
||
# implement here | ||
# IMPORTANT: avoid side effects to y, X, fh | ||
|
||
# todo: consider implementing this, optional | ||
# if not implementing, delete the method | ||
def _compute_pred_int(self, alphas): | ||
"""Calculate the prediction errors for each point. | ||
Parameters | ||
---------- | ||
alpha : float or list, optional (default=0.95) | ||
A significance level or list of significance levels. | ||
Returns | ||
------- | ||
errors : list of pd.Series | ||
Each series in the list will contain the errors for each point in | ||
the forecast for the corresponding alpha. | ||
""" | ||
# implement here | ||
|
||
# todo: consider implementing this, optional | ||
# if not implementing, delete the method | ||
def _predict_moving_cutoff( | ||
self, | ||
y, | ||
cv, | ||
X=None, | ||
update_params=True, | ||
return_pred_int=False, | ||
alpha=DEFAULT_ALPHA, | ||
): | ||
"""Make single-step or multi-step moving cutoff predictions. | ||
Parameters | ||
---------- | ||
y : pd.Series | ||
cv : temporal cross-validation generator | ||
X : pd.DataFrame | ||
update_params : bool | ||
return_pred_int : bool | ||
alpha : float or array-like | ||
Returns | ||
------- | ||
y_pred = pd.Series | ||
""" | ||
|
||
# implement here | ||
# IMPORTANT: avoid side effects to y, X, cv | ||
|
||
# todo: consider implementing this, optional | ||
# if not implementing, delete the method | ||
def get_fitted_params(self): | ||
"""Get fitted parameters. | ||
Returns | ||
------- | ||
fitted_params : dict | ||
""" | ||
# implement here |
Oops, something went wrong.