Skip to content

Commit

Permalink
[ENH] test-plus-train splitter compositor (#4862)
Browse files Browse the repository at this point in the history
This PR adds a composite splitter which takes any splitter and changes
its test splits to be the union of respective test plus train split.

Related issues and PR:
#4842
#4851
#4861
  • Loading branch information
fkiraly committed Jul 14, 2023
1 parent 00a9160 commit 0b797d7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sktime/forecasting/model_selection/__init__.py
Expand Up @@ -9,6 +9,7 @@
"SlidingWindowSplitter",
"temporal_train_test_split",
"ExpandingWindowSplitter",
"TestPlusTrainSplitter",
"ForecastingGridSearchCV",
"ForecastingRandomizedSearchCV",
"ForecastingSkoptSearchCV",
Expand All @@ -19,6 +20,7 @@
ExpandingWindowSplitter,
SingleWindowSplitter,
SlidingWindowSplitter,
TestPlusTrainSplitter,
temporal_train_test_split,
)
from sktime.forecasting.model_selection._tune import (
Expand Down
98 changes: 98 additions & 0 deletions sktime/forecasting/model_selection/_split.py
Expand Up @@ -8,6 +8,7 @@
"CutoffSplitter",
"SingleWindowSplitter",
"temporal_train_test_split",
"TestPlusTrainSplitter",
]
__author__ = ["mloning", "kkoralturk", "khrapovs", "chillerobscuro"]

Expand Down Expand Up @@ -1301,6 +1302,103 @@ def get_test_params(cls, parameter_set="default"):
return params


class TestPlusTrainSplitter(BaseSplitter):
r"""Splitter that adds the train sets to the test sets.
Takes a splitter ``cv`` and modifies it in the following way:
The i-th train sets is identical to the i-th train set of ``cv``.
The i-th test set is the union of the i-th train set and i-th test set of ``cv``.
Parameters
----------
cv : BaseSplitter
splitter to modify as above
Examples
--------
>>> from sktime.datasets import load_airline
>>> from sktime.forecasting.model_selection import ExpandingWindowSplitter
>>> y = load_airline()
>>> y_template = y[:60]
>>> cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12)
>>> splitter = TestPlusTrainSplitter(cv_tpl)
"""

def __init__(self, cv):
self.cv = cv
super().__init__()

def _split(self, y: pd.Index) -> SPLIT_GENERATOR_TYPE:
"""Get iloc references to train/test splits of `y`.
private _split containing the core logic, called from split
Parameters
----------
y : pd.Index or time series in sktime compatible time series format
Time series to split, or index of time series to split
Yields
------
train : 1D np.ndarray of dtype int
Training window indices, iloc references to training indices in y
test : 1D np.ndarray of dtype int
Test window indices, iloc references to test indices in y
"""
cv = self.cv

for y_train_inner, y_test_inner in cv.split(y):
y_train_self = y_train_inner
y_test_self = np.union1d(y_train_inner, y_test_inner)
yield y_train_self, y_test_self

def get_n_splits(self, y: Optional[ACCEPTED_Y_TYPES] = None) -> int:
"""Return the number of splits.
This will always be equal to the number of splits
of ``self.cv`` on ``y``.
Parameters
----------
y : pd.Series or pd.Index, optional (default=None)
Time series to split
Returns
-------
n_splits : int
The number of splits.
"""
return self.cv.get_n_splits(y)

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the splitter.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
from sktime.forecasting.model_selection import ExpandingWindowSplitter

cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12)

params = {"cv": cv_tpl}

return params


def temporal_train_test_split(
y: ACCEPTED_Y_TYPES,
X: Optional[pd.DataFrame] = None,
Expand Down

0 comments on commit 0b797d7

Please sign in to comment.