diff --git a/sktime/forecasting/model_selection/__init__.py b/sktime/forecasting/model_selection/__init__.py index 3e9874274fb..c092f17599f 100644 --- a/sktime/forecasting/model_selection/__init__.py +++ b/sktime/forecasting/model_selection/__init__.py @@ -9,6 +9,7 @@ "SlidingWindowSplitter", "temporal_train_test_split", "ExpandingWindowSplitter", + "TestPlusTrainSplitter", "ForecastingGridSearchCV", "ForecastingRandomizedSearchCV", "ForecastingSkoptSearchCV", @@ -19,6 +20,7 @@ ExpandingWindowSplitter, SingleWindowSplitter, SlidingWindowSplitter, + TestPlusTrainSplitter, temporal_train_test_split, ) from sktime.forecasting.model_selection._tune import ( diff --git a/sktime/forecasting/model_selection/_split.py b/sktime/forecasting/model_selection/_split.py index 27462a6c7c8..20a5a767cb9 100644 --- a/sktime/forecasting/model_selection/_split.py +++ b/sktime/forecasting/model_selection/_split.py @@ -8,6 +8,7 @@ "CutoffSplitter", "SingleWindowSplitter", "temporal_train_test_split", + "TestPlusTrainSplitter", ] __author__ = ["mloning", "kkoralturk", "khrapovs", "chillerobscuro"] @@ -1301,6 +1302,103 @@ def get_test_params(cls, parameter_set="default"): return params +class TestPlusTrainSplitter(BaseSplitter): + r"""Splitter that adds the train sets to the test sets. + + Takes a splitter ``cv`` and modifies it in the following way: + The i-th train sets is identical to the i-th train set of ``cv``. + The i-th test set is the union of the i-th train set and i-th test set of ``cv``. + + Parameters + ---------- + cv : BaseSplitter + splitter to modify as above + + Examples + -------- + >>> from sktime.datasets import load_airline + >>> from sktime.forecasting.model_selection import ExpandingWindowSplitter + + >>> y = load_airline() + >>> y_template = y[:60] + >>> cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12) + + >>> splitter = TestPlusTrainSplitter(cv_tpl) + """ + + def __init__(self, cv): + self.cv = cv + super().__init__() + + def _split(self, y: pd.Index) -> SPLIT_GENERATOR_TYPE: + """Get iloc references to train/test splits of `y`. + + private _split containing the core logic, called from split + + Parameters + ---------- + y : pd.Index or time series in sktime compatible time series format + Time series to split, or index of time series to split + + Yields + ------ + train : 1D np.ndarray of dtype int + Training window indices, iloc references to training indices in y + test : 1D np.ndarray of dtype int + Test window indices, iloc references to test indices in y + """ + cv = self.cv + + for y_train_inner, y_test_inner in cv.split(y): + y_train_self = y_train_inner + y_test_self = np.union1d(y_train_inner, y_test_inner) + yield y_train_self, y_test_self + + def get_n_splits(self, y: Optional[ACCEPTED_Y_TYPES] = None) -> int: + """Return the number of splits. + + This will always be equal to the number of splits + of ``self.cv`` on ``y``. + + Parameters + ---------- + y : pd.Series or pd.Index, optional (default=None) + Time series to split + + Returns + ------- + n_splits : int + The number of splits. + """ + return self.cv.get_n_splits(y) + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the splitter. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + from sktime.forecasting.model_selection import ExpandingWindowSplitter + + cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12) + + params = {"cv": cv_tpl} + + return params + + def temporal_train_test_split( y: ACCEPTED_Y_TYPES, X: Optional[pd.DataFrame] = None,