Skip to content

Commit

Permalink
@corradomio [ENH] config to turn off data memory in forecasters (#5676)
Browse files Browse the repository at this point in the history
This PR adds a boolean config field, `remember_data`, which can be used
to turn off storing of `self._X` and `self._y`.

This was requested in discussion
#5545, and is enabled by
#5590.

Open question: this does not turn off storing of `self._yvec`, which can store data.
Ideally, `VectorizedDF` would also have a similar option - this is feasible,
as most of its logic only uses indices internally.
  • Loading branch information
fkiraly committed Jan 24, 2024
1 parent f1ae814 commit 19e249b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
30 changes: 22 additions & 8 deletions sktime/forecasting/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ class BaseForecaster(BaseEstimator):
# "joblib": uses custom joblib backend, set via `joblib_backend` tag
# "dask": uses `dask`, requires `dask` package in environment
"backend:parallel:params": None, # params for parallelization backend
"remember_data": True, # whether to remember data in fit - self._X, self._y
}

_config_doc = {
"remember_data": """
remember_data : bool, default=True
whether self._X and self._y are stored in fit, and updated
in update. If True, self._X and self._y are stored and updated.
If False, self._X and self._y are not stored and updated.
This reduces serialization size when using save,
but the update will default to "do nothing" rather than
"refit to all data seen".
""",
}

def __init__(self):
Expand Down Expand Up @@ -1144,7 +1157,7 @@ def predict_residuals(self, y=None, X=None):
fh_orig = None

# if no y is passed, the so far observed y is used
if y is None:
if y is None and self.get_config()["remember_data"]:
y = self._y

# we want residuals, so fh must be the index of y
Expand Down Expand Up @@ -1530,10 +1543,9 @@ def _check_X(self, X=None):
return self._check_X_y(X=X)[0]

def _update_X(self, X, enforce_index_type=None):
if X is not None:
if X is not None and self.get_config()["remember_data"]:
X = check_X(X, enforce_index_type=enforce_index_type)
if X is len(X) > 0:
self._X = X.combine_first(self._X)
self._X = update_data(self._X, X)

def _update_y_X(self, y, X=None, enforce_index_type=None):
"""Update internal memory of seen training data.
Expand Down Expand Up @@ -1562,7 +1574,7 @@ def _update_y_X(self, y, X=None, enforce_index_type=None):
X : pd.DataFrame or 2D np.ndarray, optional (default=None)
Exogeneous time series
"""
if y is not None:
if y is not None and self.get_config()["remember_data"]:
# unwrap y if VectorizedDF
if isinstance(y, VectorizedDF):
y = y.X_multiindex
Expand All @@ -1575,7 +1587,7 @@ def _update_y_X(self, y, X=None, enforce_index_type=None):
# set cutoff to the end of the observation horizon
self._set_cutoff_from_y(y)

if X is not None:
if X is not None and self.get_config()["remember_data"]:
# unwrap X if VectorizedDF
if isinstance(X, VectorizedDF):
X = X.X_multiindex
Expand Down Expand Up @@ -1911,7 +1923,7 @@ def _update(self, y, X=None, update_params=True):
-------
self : reference to self
"""
if update_params:
if update_params and self.get_config()["remember_data"]:
# default to re-fitting if update is not implemented
warn(
f"NotImplementedWarning: {self.__class__.__name__} "
Expand All @@ -1924,12 +1936,14 @@ def _update(self, y, X=None, update_params=True):
)
# we need to overwrite the mtype last seen and converter store, since the _y
# may have been converted
mtype_last_seen = self._y_mtype_last_seen
y_metadata = self._y_metadata
_converter_store_y = self._converter_store_y
# refit with updated data, not only passed data
self.fit(y=self._y, X=self._X, fh=self._fh)
# todo: should probably be self._fit, not self.fit
# but looping to self.fit for now to avoid interface break
self._y_mtype_last_seen = mtype_last_seen
self._y_metadata = y_metadata
self._converter_store_y = _converter_store_y

Expand Down Expand Up @@ -2164,7 +2178,7 @@ def _predict_var(self, fh=None, X=None, cov=False):
fh = fh.to_absolute(self.cutoff)
pred_var.index = fh.to_pandas()

if isinstance(self._y, pd.DataFrame):
if isinstance(pred_var, pd.DataFrame):
pred_var.columns = self._get_columns(method="predict_var")

return pred_var
Expand Down
23 changes: 23 additions & 0 deletions sktime/forecasting/base/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sktime.datatypes import check_is_mtype, convert
from sktime.datatypes._utilities import get_cutoff, get_window
from sktime.forecasting.arima import ARIMA
from sktime.forecasting.compose import YfromX
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.theta import ThetaForecaster
from sktime.forecasting.var import VAR
Expand Down Expand Up @@ -455,3 +456,25 @@ def test_range_fh_in_predict():

assert isinstance(var_predictions, pd.DataFrame)
assert var_predictions.shape == (10 * 2, 5)


def test_remember_data():
    """Test that the ``remember_data`` config flag works as expected.

    With ``remember_data=False``, ``fit`` must not populate ``self._X``
    and ``self._y``; switching back to ``True`` restores the default
    behaviour of storing both.
    """
    from sktime.datasets import load_airline

    endog = load_airline()
    exog = load_airline()
    forecaster = YfromX.create_test_instance()

    # with data memory switched off, fit should leave _X and _y unset
    forecaster.set_config(remember_data=False)
    forecaster.fit(endog, exog, fh=[1, 2, 3])

    assert forecaster._X is None
    assert forecaster._y is None

    # re-enabling the flag restores storing of _X and _y in fit
    forecaster.set_config(remember_data=True)
    forecaster.fit(endog, exog, fh=[1, 2, 3])

    assert forecaster._X is not None
    assert forecaster._y is not None
2 changes: 1 addition & 1 deletion sktime/forecasting/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _fit(self, y, X, fh):
)

# check window length
if self.window_length_ > len(self._y):
if self.window_length_ > len(y):
param = "sp" if self.strategy == "last" and sp != 1 else "window_length_"
raise ValueError(
f"The {param}: {self.window_length_} is larger than "
Expand Down

0 comments on commit 19e249b

Please sign in to comment.