Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] splitter that replicates loc of another splitter #4851

Merged
merged 8 commits into from Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions sktime/forecasting/model_selection/__init__.py
Expand Up @@ -5,6 +5,7 @@
__author__ = ["mloning", "kkoralturk"]
__all__ = [
"CutoffSplitter",
"SameLocSplitter",
"SingleWindowSplitter",
"SlidingWindowSplitter",
"temporal_train_test_split",
Expand All @@ -16,6 +17,7 @@
from sktime.forecasting.model_selection._split import (
CutoffSplitter,
ExpandingWindowSplitter,
SameLocSplitter,
SingleWindowSplitter,
SlidingWindowSplitter,
temporal_train_test_split,
Expand Down
115 changes: 114 additions & 1 deletion sktime/forecasting/model_selection/_split.py
Expand Up @@ -7,9 +7,10 @@
"SlidingWindowSplitter",
"CutoffSplitter",
"SingleWindowSplitter",
"SameLocSplitter",
"temporal_train_test_split",
]
__author__ = ["mloning", "kkoralturk", "khrapovs", "chillerobscuro"]
__author__ = ["mloning", "kkoralturk", "khrapovs", "chillerobscuro", "fkiraly"]

from typing import Iterator, Optional, Tuple, Union

Expand Down Expand Up @@ -332,6 +333,10 @@ class BaseSplitter(BaseObject):
Single step ahead or array of steps ahead to forecast.
"""

_tags = {"split_hierarchical": False}
# split_hierarchical: whether _split supports hierarchical types natively
# if not, splitter broadcasts over instances

def __init__(
self,
fh: FORECASTING_HORIZON_TYPES = DEFAULT_FH,
Expand Down Expand Up @@ -364,6 +369,8 @@ def split(self, y: ACCEPTED_Y_TYPES) -> SPLIT_GENERATOR_TYPE:

if not isinstance(y_index, pd.MultiIndex):
split = self._split
elif self.get_tag("split_hierarchical", False, raise_error=False):
split = self._split
else:
split = self._split_vectorized

Expand Down Expand Up @@ -1301,6 +1308,112 @@ def get_test_params(cls, parameter_set="default"):
return params


class SameLocSplitter(BaseSplitter):
r"""Splitter that replicates loc indices from another splitter.

Takes a splitter ``cv`` and a time series ``y_template``.
Splits ``y`` in ``split`` and ``split_loc`` such that ``loc`` indices of splits
are identical to loc indices of ``cv`` applied to ``y_template``.

Parameters
----------
cv : BaseSplitter
splitter for which to replicate splits by ``loc`` index
y_template : time series container of ``Series`` scitype, optional
template used in ``cv`` to determine ``loc`` indices
if None, ``y_template=y`` will be used in methods

Examples
--------
>>> from sktime.datasets import load_airline
>>> from sktime.forecasting.model_selection import (
... ExpandingWindowSplitter,
... SameLocSplitter,
... )

>>> y = load_airline()
>>> y_template = y[:60]
>>> cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12)

>>> splitter = SameLocSplitter(cv_tpl, y_template)

these two are the same:
>>> list(cv_tpl.split(y_template)) # doctest: +SKIP
>>> list(splitter.split(y)) # doctest: +SKIP
"""

_tags = {"split_hierarchical": True}
# SameLocSplitter supports hierarchical pandas index

def __init__(self, cv, y_template=None):
self.cv = cv
self.y_template = y_template
super().__init__()

def _split(self, y: pd.Index) -> SPLIT_GENERATOR_TYPE:
cv = self.cv
if self.y_template is None:
y_template = y
else:
y_template = self.y_template

for y_train_loc, y_test_loc in cv.split_loc(y_template):
y_train_iloc = y.get_indexer(y_train_loc)
y_test_iloc = y.get_indexer(y_test_loc)
yield y_train_iloc, y_test_iloc

def get_n_splits(self, y: Optional[ACCEPTED_Y_TYPES] = None) -> int:
"""Return the number of splits.

Since this splitter returns a single train/test split,
this number is trivially 1.
benHeid marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
y : pd.Series or pd.Index, optional (default=None)
Time series to split

Returns
-------
n_splits : int
The number of splits.
"""
if self.y_template is None:
y_template = y
else:
y_template = self.y_template
return self.cv.get_n_splits(y_template)

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the splitter.

Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.

Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
from sktime.datasets import load_airline
from sktime.forecasting.model_selection import ExpandingWindowSplitter

y = load_airline()
y_template = y[:60]
cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12)

params = {"cv": cv_tpl, "y_template": y_template}

return params


def temporal_train_test_split(
y: ACCEPTED_Y_TYPES,
X: Optional[pd.DataFrame] = None,
Expand Down
56 changes: 56 additions & 0 deletions sktime/forecasting/model_selection/tests/test_split.py
Expand Up @@ -12,6 +12,7 @@
from sktime.forecasting.model_selection import (
CutoffSplitter,
ExpandingWindowSplitter,
SameLocSplitter,
SingleWindowSplitter,
SlidingWindowSplitter,
temporal_train_test_split,
Expand Down Expand Up @@ -568,3 +569,58 @@ def inst_index(y):
assert len(test) == 1 * n_instances
assert inst_index(train) == inst_index(y)
assert inst_index(test) == inst_index(y)


def test_same_loc_splitter():
"""Test that SameLocSplitter works as intended."""
from sktime.datasets import load_airline

y = load_airline()
y_template = y[:60]
cv_tpl = ExpandingWindowSplitter(fh=[2, 4], initial_window=24, step_length=12)

splitter = SameLocSplitter(cv_tpl, y_template)

# these should be the same
# not in general, but only because y is longer only at the end
split_template_iloc = list(cv_tpl.split(y_template))
split_templated_iloc = list(splitter.split(y))

for (t1, tt1), (t2, tt2) in zip(split_template_iloc, split_templated_iloc):
assert np.all(t1 == t2)
assert np.all(tt1 == tt2)

# these should be in general the same
split_template_loc = list(cv_tpl.split_loc(y_template))
split_templated_loc = list(splitter.split_loc(y))

for (t1, tt1), (t2, tt2) in zip(split_template_loc, split_templated_loc):
assert np.all(t1 == t2)
assert np.all(tt1 == tt2)


def test_same_loc_splitter_hierarchical():
"""Test that SameLocSplitter works as intended for hierarchical data."""
hierarchy_levels1 = (2, 2)
hierarchy_levels2 = (3, 4)
n1 = 7
n2 = 2 * n1
y_template = _make_hierarchical(
hierarchy_levels=hierarchy_levels1, max_timepoints=n1, min_timepoints=n1
)

y = _make_hierarchical(
hierarchy_levels=hierarchy_levels2, max_timepoints=n2, min_timepoints=n2
)

cv_tpl = ExpandingWindowSplitter(fh=[1, 2], initial_window=1, step_length=2)

splitter = SameLocSplitter(cv_tpl, y_template)

# these should be in general the same
split_template_loc = list(cv_tpl.split_loc(y_template))
split_templated_loc = list(splitter.split_loc(y))

for (t1, tt1), (t2, tt2) in zip(split_template_loc, split_templated_loc):
assert np.all(t1 == t2)
assert np.all(tt1 == tt2)
6 changes: 6 additions & 0 deletions sktime/registry/_tags.py
Expand Up @@ -395,6 +395,12 @@
("list", "str"),
"parameters reserved by the base class and present in all child estimators",
),
(
"split_hierarchical",
"splitter",
"bool",
"whether _split is natively implemented for hierarchical y types",
),
(
"capabilities:exact",
"distribution",
Expand Down