Skip to content

Commit

Permalink
Forecast base class refactor and extension template (#912)
Browse files Browse the repository at this point in the history
This PR is addressing the forecaster base class proliferation discussed in #510, by simplifying the base class inheritance tree and the streamlining the private method logic.

Specifically, I've moved the logic from `_SktimeForecaster` and the forecasting horizin mixins into `BaseForecaster`, and adopted uniformly the principle of having a public method (`fit`, `predict`, etc) with checks/plumbing, dispatching to a "core logic" version (`_fit`, `_predict`, etc) where validated arguments can be assumed.

The burden on extenders becomes much lighter, since it is now possible to only focus on the "core logic" implementation, instead of having to keep in mind a myriad of inconsistent and constantly shifting conventions in checks and other plumbing.

As side effects, if we get this right, this should make a few things on the roadmap easier:
* the extension guidelines #464. Right now, the aforementioned implicit conventions are too many and intricate to write useful extension guidelines, in my opinion.
* extending to the multivariate case
* input/output checks and the eternal data container discussion, this can go in the plumbing

As a proof-of-concept regarding ease of extension, this PR also contains a highly annotated extension template in the `extension_templates` folder.

In terms of review, the key file is `forecasting/base/_base`, with corresponding changes (contraction and deletion) in `base/_sktime`.

I've tried to keep the interface consistent as much as possible (only changing internal logic).

Interface contracts with all the earlier estimators are still honoured, via loopthroughs and default behaviour that ensures that everything still works if `fit`, `predict` etc are overridden by the current descendants, as opposed to `_fit` and `_predict`.
The only change I had to made to descendants is set `self._is_fitted=False` at the start of `fit`, which is a minimally invasive change that's also separately reviewable as PR #941.
  • Loading branch information
fkiraly committed Jun 12, 2021
1 parent 90b1ccf commit ad82a27
Show file tree
Hide file tree
Showing 5 changed files with 918 additions and 606 deletions.
3 changes: 3 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
"profile": "https://github.com/fkiraly",
"contributions": [
"business",
"code",
"doc",
"design",
"eventOrganizing",
"financial",
"fundingFinding",
"ideas",
"mentoring",
"projectManagement",
"question",
"research",
"review",
"talk"
]
Expand Down
256 changes: 256 additions & 0 deletions extension_templates/forecasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# -*- coding: utf-8 -*-
"""
Extension template for forecasters.
How to use this:
- this is meant as a "fill in" template for easy extension
- do NOT import this file directly - it will break
- work through all the "todo" comments below
- fill in code for mandatory methods, and optionally for optional methods
- you can add more private methods, but do not override BaseEstimator's private methods
an easy way to be safe is to prefix your methods with "_custom"
- change docstrings for functions and the file
- ensure interface compatibility by testing forecasting/tests/test_all_forecasters
and forecasting/tests/test_sktime_forecasters
- once complete: use as a local library, or contribute to sktime via PR
Mandatory implements:
fitting - _fit(self, y, X=None, fh=None)
forecasting - _predict(self, fh=None, X=None, return_pred_int=False,
alpha=DEFAULT_ALPHA)
Optional implements:
updating - _update(self, y, X=None, update_params=True):
fitted parameter inspection - get_fitted_params()
State:
fitted model/strategy - by convention, any attributes ending in "_"
fitted state flag - check_is_fitted()
copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""

from sktime.forecasting.base import BaseForecaster
from sktime.forecasting.base._base import DEFAULT_ALPHA

# todo: add any necessary imports here


class MyForecaster(BaseForecaster):
"""Custom forecaster. todo: write docstring.
todo: describe your custom forecaster here
Hyper-parameters
----------------
parama : int
descriptive explanation of parama
paramb : string, optional (default='default')
descriptive explanation of paramb
paramc : boolean, optional (default= whether paramb is not the default)
descriptive explanation of paramc
and so on
Components
----------
est : sktime.estimator, BaseEstimator descendant
descriptive explanation of est
est2: another estimator
descriptive explanation of est2
and so on
"""

# todo: fill out estimator tags here
_tags = {
"fh_in_fit": True, # is the forecasting horizon already required in fit?
"handles_missing_data": False, # can the estimator handle missing data?
"univariate_only": True, # can the estimator deal with multivariate series?
}
# in case of inheritance, concrete class should set all tags
# alternatively, descendants can set tags in __init__ (avoid this if possible)

# todo: add any hyper-parameters and components to constructor
def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# estimators should precede parameters
# if estimators have default values, set None and initalize below

# todo: write any hyper-parameters and components to self
self.est = est
self.parama = parama
self.paramb = paramb
self.paramc = paramc
# important: no checking or other logic should happen here

# todo: default estimators should have None arg defaults
# and be initialized here
# do this only with default estimators, not with parameters
# if est2 is None:
# self.estimator = MyDefaultEstimator()

# todo: change "MyForecaster" to the name of the class
super(MyForecaster, self).__init__()

# todo: implement this, mandatory
def _fit(self, y, X=None, fh=None):
"""Fit forecaster to training data.
core logic
Parameters
----------
y : pd.Series
Target time series to which to fit the forecaster.
fh : int, list, np.array or ForecastingHorizon, optional (default=None)
The forecasters horizon with the steps ahead to to predict.
X : pd.DataFrame, optional (default=None)
Returns
-------
self : returns an instance of self.
"""

# implement here
# IMPORTANT: avoid side effects to y, X, fh

# todo: implement this, mandatory
def _predict(self, fh, X=None, return_pred_int=False, alpha=DEFAULT_ALPHA):
"""Forecast time series at future horizon.
core logic
Parameters
----------
fh : int, list, np.array or ForecastingHorizon
Forecasting horizon
X : pd.DataFrame, optional (default=None)
Exogenous time series
return_pred_int : bool, optional (default=False)
If True, returns prediction intervals for given alpha values.
alpha : float or list, optional (default=0.95)
Returns
-------
y_pred : pd.Series
Point predictions
y_pred_int : pd.DataFrame - only if return_pred_int=True
Prediction intervals
"""

# implement here
# IMPORTANT: avoid side effects to X, fh

# todo: consider implementing this, optional
# if not implementing, delete the _update method
def _update(self, y, X=None, update_params=True):
"""Update time series to incremental training data.
core logic
Parameters
----------
fh : int, list, np.array or ForecastingHorizon
Forecasting horizon
X : pd.DataFrame, optional (default=None)
Exogenous time series
return_pred_int : bool, optional (default=False)
If True, returns prediction intervals for given alpha values.
alpha : float or list, optional (default=0.95)
Returns
-------
y_pred : pd.Series
Point predictions
y_pred_int : pd.DataFrame - only if return_pred_int=True
Prediction intervals
State change
------------
updates self._X and self._y with new data
updates self.cutoff to most recent time in y
if update_params=True, updates model (attributes ending in "_")
"""

# implement here
# IMPORTANT: avoid side effects to X, fh

# todo: consider implementing this, optional
# if not implementing, delete the method
def _update_predict_single(
self,
y,
fh,
X=None,
update_params=True,
return_pred_int=False,
alpha=DEFAULT_ALPHA,
):
"""Update forecaster and then make forecasts.
Implements default behaviour of calling update and predict
sequentially, but can be overwritten by subclasses
to implement more efficient updating algorithms when available.
"""
self.update(y, X, update_params=update_params)
return self.predict(fh, X, return_pred_int=return_pred_int, alpha=alpha)

# implement here
# IMPORTANT: avoid side effects to y, X, fh

# todo: consider implementing this, optional
# if not implementing, delete the method
def _compute_pred_int(self, alphas):
"""Calculate the prediction errors for each point.
Parameters
----------
alpha : float or list, optional (default=0.95)
A significance level or list of significance levels.
Returns
-------
errors : list of pd.Series
Each series in the list will contain the errors for each point in
the forecast for the corresponding alpha.
"""
# implement here

# todo: consider implementing this, optional
# if not implementing, delete the method
def _predict_moving_cutoff(
self,
y,
cv,
X=None,
update_params=True,
return_pred_int=False,
alpha=DEFAULT_ALPHA,
):
"""Make single-step or multi-step moving cutoff predictions.
Parameters
----------
y : pd.Series
cv : temporal cross-validation generator
X : pd.DataFrame
update_params : bool
return_pred_int : bool
alpha : float or array-like
Returns
-------
y_pred = pd.Series
"""

# implement here
# IMPORTANT: avoid side effects to y, X, cv

# todo: consider implementing this, optional
# if not implementing, delete the method
def get_fitted_params(self):
"""Get fitted parameters.
Returns
-------
fitted_params : dict
"""
# implement here
Loading

0 comments on commit ad82a27

Please sign in to comment.