ENH Add option n_splits='walk_forward' in TimeSeriesSplit #23780

Open

wants to merge 23 commits into base: main

Changes from 19 commits

23 commits
a01e12d
provided new link for `SVD based initialization` link on line 959, re…
ShehanAT Jun 19, 2022
38c78cf
added find_walk_forward_n_splits_value() to sklearn/model_selection/_…
ShehanAT Jun 28, 2022
3edd8db
fixed merge conflict in doc/modules/decomposition.rst
ShehanAT Jun 28, 2022
9764f6d
refactored find_walk_forward_n_splits_value() in sklearn/model_select…
ShehanAT Jun 28, 2022
29ff1dd
fixed grammar error in docstring in sklearn/model_selection/_split.py
ShehanAT Jun 28, 2022
4512e66
made changes suggested by @glemaitre
ShehanAT Jun 28, 2022
d1ec219
refactored sklearn/model_selection/_split.py according to @glemaitre
ShehanAT Jun 28, 2022
b5cf2cf
added two unit tests regard 'walk_forward' TimeSeriesSplit feature
ShehanAT Jun 30, 2022
ee3315a
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
ShehanAT Jun 30, 2022
52642e6
added entry to doc/whats_new/v1.2.rst
ShehanAT Jun 30, 2022
06c5b44
updated docstring to mention 'walk_forward' support for model_selecti…
ShehanAT Jun 30, 2022
a63b61d
Merge remote-tracking branch 'origin/main' into pr/ShehanAT/23780-1
glemaitre Nov 3, 2022
97f380d
DOC move the entry in changelog
glemaitre Nov 3, 2022
0cfbb7d
DOC update docstring for walk_forward
glemaitre Nov 3, 2022
47fbc01
DOC add user guide documentation
glemaitre Nov 3, 2022
481cf56
TST reformat tests
glemaitre Nov 3, 2022
9a27b48
FIX add dosctring and fix n_split
glemaitre Nov 3, 2022
7876b0a
FIX make sure to take gap into account
glemaitre Nov 3, 2022
d907960
DOC add gap into examples in user guide
glemaitre Nov 3, 2022
cc5e970
DOC fix docstring
glemaitre Nov 10, 2022
bad0678
Merge branch 'main' into walk-forward-time-series-split-ShehanAT
glemaitre May 31, 2023
8df4f5b
DOC Update to 1.3
thomasjpfan May 31, 2023
8efb3aa
Merge branch 'main' into walk-forward-time-series-split-ShehanAT
glemaitre Dec 7, 2023
13 changes: 13 additions & 0 deletions doc/modules/cross_validation.rst
@@ -902,6 +902,19 @@ Here is a visualization of the cross-validation behavior.
:align: center
:scale: 75%

To obtain cross-validation splits in which the training set keeps a constant
number of samples, i.e. a rolling window, set the parameter `n_splits` to
`"walk_forward"`. The cross-validation splits will then look like::

>>> tscv = TimeSeriesSplit(
... n_splits="walk_forward", max_train_size=2, test_size=1, gap=1
... )
>>> for train, test in tscv.split(X):
... print(train, test)
[0 1] [3]
[1 2] [4]
[2 3] [5]

A note on shuffling
===================

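The user-guide snippet above relies on an `X` defined earlier in `cross_validation.rst`. For readers trying it in isolation, here is a self-contained sketch; it assumes a scikit-learn build that includes this branch (the `"walk_forward"` option is not in a released version), and the 6-sample `X` is just a stand-in since only the indices matter.

import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(6).reshape(6, 1)  # placeholder data: 6 samples, indices 0..5

# Rolling window: at most 2 training samples, 1 test sample per split, and a
# gap of 1 sample between the end of the training window and the test sample.
tscv = TimeSeriesSplit(n_splits="walk_forward", max_train_size=2, test_size=1, gap=1)
for train, test in tscv.split(X):
    print(train, test)
# Expected output on this branch:
# [0 1] [3]
# [1 2] [4]
# [2 3] [5]
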
7 changes: 7 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -487,6 +487,13 @@ Changelog
scores will all be set to the maximum possible rank.
:pr:`24543` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Enhancement| :class:`model_selection.TimeSeriesSplit` now accepts
  `n_splits='walk_forward'`, which enables rolling-window cross-validation in
  which the number of splits is computed automatically.
  :pr:`23780` by :user:`Sean Atukorala <ShehanAT>` and
  :user:`Guillaume Lemaitre <glemaitre>`.
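
A minimal usage sketch for the entry above, assuming this branch: the splitter plugs into the usual tools such as `cross_val_score`, and the number of folds comes out of the rolling-window formula rather than being passed explicitly. The estimator and data are arbitrary placeholders.

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import TimeSeriesSplit, cross_val_score

rng = np.random.RandomState(0)
X, y = rng.randn(30, 4), rng.randn(30)  # placeholder regression data

# Rolling window: every training fold has exactly 5 samples.
cv = TimeSeriesSplit(n_splits="walk_forward", max_train_size=5, test_size=5)
scores = cross_val_score(Ridge(), X, y, cv=cv)
print(len(scores))  # (30 - 5 - 0) // 5 = 5 folds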

:mod:`sklearn.multioutput`
..........................

93 changes: 74 additions & 19 deletions sklearn/model_selection/_split.py
@@ -287,22 +287,21 @@ class _BaseKFold(BaseCrossValidator, metaclass=ABCMeta):

     @abstractmethod
     def __init__(self, n_splits, *, shuffle, random_state):
-        if not isinstance(n_splits, numbers.Integral):
-            raise ValueError(
-                "The number of folds must be of Integral type. "
-                "%s of type %s was passed." % (n_splits, type(n_splits))
-            )
-        n_splits = int(n_splits)
-
-        if n_splits <= 1:
+        if isinstance(n_splits, numbers.Integral):
+            if n_splits <= 1:
+                raise ValueError(
+                    "k-fold cross-validation requires at least one"
+                    " train/test split by setting n_splits=2 or more,"
+                    f" got n_splits={n_splits}."
+                )
+        elif n_splits != "walk_forward":
             raise ValueError(
-                "k-fold cross-validation requires at least one"
-                " train/test split by setting n_splits=2 or more,"
-                " got n_splits={0}.".format(n_splits)
+                "n_splits should be an integer number or 'walk_forward' for "
+                "the TimeSeriesSplit cross-validator."
             )
 
         if not isinstance(shuffle, bool):
-            raise TypeError("shuffle must be True or False; got {0}".format(shuffle))
+            raise TypeError(f"shuffle must be True or False; got {shuffle}")
 
         if not shuffle and random_state is not None:  # None is the default
             raise ValueError(
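
To make the reworked validation concrete, here is a short hypothetical session showing how the check above is expected to behave on this branch; the failing calls are left commented out and the error messages are abbreviated.

from sklearn.model_selection import KFold, TimeSeriesSplit

# Accepted: the string sentinel passes the shared _BaseKFold validation and is
# interpreted by TimeSeriesSplit (max_train_size and test_size are required).
TimeSeriesSplit(n_splits="walk_forward", max_train_size=3, test_size=2)

# KFold(n_splits=1)
# ValueError: k-fold cross-validation requires at least one train/test split ...

# TimeSeriesSplit(n_splits="unknown")
# ValueError: n_splits should be an integer number or 'walk_forward' for the ...
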
@@ -995,12 +994,16 @@ class TimeSeriesSplit(_BaseKFold):

Parameters
----------
- n_splits : int, default=5
-     Number of splits. Must be at least 2.
+ n_splits : "walk_forward" or int, default=5
+     Number of splits. If an integer, it must be at least 2. If
+     `"walk_forward"`, the number of splits is computed automatically to
+     obtain a rolling window.

.. versionchanged:: 0.22
``n_splits`` default value changed from 3 to 5.

.. versionadded:: 1.2
Added the option `"walk_forward"` for rolling window support.

max_train_size : int, default=None
Maximum size for a single training set.

@@ -1077,17 +1080,46 @@
Fold 2:
Train: index=[0 1 2 3 4 5 6 7]
Test: index=[10 11]
>>> # Showing rolling window support via `n_splits='walk_forward'`
>>> X = np.random.randn(15, 2)
>>> tscv = TimeSeriesSplit(n_splits='walk_forward', max_train_size=3, test_size=3)
>>> for i, (train_index, test_index) in enumerate(tscv.split(X)):
...     print(f"Fold {i}:")
...     print(f" Train: index={train_index}")
...     print(f" Test: index={test_index}")
Fold 0:
Train: index=[0 1 2]
Test: index=[3 4 5]
Fold 1:
Train: index=[3 4 5]
Test: index=[6 7 8]
Fold 2:
Train: index=[6 7 8]
Test: index=[ 9 10 11]
Fold 3:
Train: index=[ 9 10 11]
Test: index=[12 13 14]

Notes
-----
- The training set has size ``i * n_samples // (n_splits + 1)
- + n_samples % (n_splits + 1)`` in the ``i`` th split,
- with a test set of size ``n_samples//(n_splits + 1)`` by default,
- where ``n_samples`` is the number of samples.
+ - The training set has size ``i * n_samples // (n_splits + 1) + n_samples %
+   (n_splits + 1)`` in the ``i`` th split, with a test set of size
+   ``n_samples//(n_splits + 1)`` by default, where ``n_samples`` is the
+   number of samples.
+ - To use a rolling window, where the training set does not grow and the
+   number of splits is computed automatically, set `n_splits='walk_forward'`.
"""

def __init__(self, n_splits=5, *, max_train_size=None, test_size=None, gap=0):
super().__init__(n_splits, shuffle=False, random_state=None)
if self.n_splits == "walk_forward" and (
max_train_size is None or test_size is None
):
raise ValueError(
"If `n_splits='walk_forward', then `max_train_size` and `test_size` "
"must be specified."
)
self.max_train_size = max_train_size
self.test_size = test_size
self.gap = gap
@@ -1117,7 +1149,7 @@ def split(self, X, y=None, groups=None):
"""
X, y, groups = indexable(X, y, groups)
n_samples = _num_samples(X)
- n_splits = self.n_splits
+ n_splits = self.get_n_splits(X, y, groups)
n_folds = n_splits + 1
gap = self.gap
test_size = (
@@ -1152,6 +1184,29 @@
indices[test_start : test_start + test_size],
)

def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the cross-validator.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Used when `n_splits='walk_forward'` to compute the number of splits; otherwise ignored.

y : object
Always ignored, exists for compatibility.

groups : object
Always ignored, exists for compatibility.

Returns
-------
n_splits : int
Returns the number of splitting iterations in the cross-validator.
"""
if self.n_splits == "walk_forward":
return (_num_samples(X) - self.max_train_size - self.gap) // self.test_size
return self.n_splits


class LeaveOneGroupOut(BaseCrossValidator):
"""Leave One Group Out cross-validator
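
The number of splits returned by the new `get_n_splits` comes straight from `(n_samples - max_train_size - gap) // test_size`. As a rough sanity check on that arithmetic, here is a small sketch; the helper name is made up for illustration and simply mirrors the expression above.

def expected_walk_forward_n_splits(n_samples, max_train_size, test_size, gap=0):
    # Mirrors the expression used when n_splits='walk_forward'.
    return (n_samples - max_train_size - gap) // test_size

# 15 samples, a training window of 3 and test blocks of 2 give 6 splits,
# the value asserted in the new tests below.
print(expected_walk_forward_n_splits(15, 3, 2))         # 6
# A gap of 2 between train and test leaves room for one split fewer.
print(expected_walk_forward_n_splits(15, 3, 2, gap=2))  # 5
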
68 changes: 68 additions & 0 deletions sklearn/model_selection/tests/test_split.py
@@ -1667,6 +1667,74 @@ def test_time_series_cv():
assert n_splits_actual == 2


@pytest.mark.parametrize(
"n_splits, err_msg",
[
("walk_forward", "If `n_splits='walk_forward', then `max_train_size`"),
("unknown", "n_splits should be an integer number or 'walk_forward'"),
],
)
def test_time_series_n_splits_errors(n_splits, err_msg):
"""Check that we raise the proper error message if `n_splits` is invalid."""
with pytest.raises(ValueError, match=err_msg):
TimeSeriesSplit(n_splits=n_splits)


def test_time_series_rolling_windows():
"""Check the behaviour of the `TimeSeriesSplit` with `n_splits='walk_forward'`."""
X = np.random.randn(15, 2)

max_train_size, test_size = 3, 2
tscv = TimeSeriesSplit(
n_splits="walk_forward", max_train_size=max_train_size, test_size=test_size
)
assert tscv.get_n_splits(X) == 6

expected_splits = [
[[0, 1, 2], [3, 4]],
[[2, 3, 4], [5, 6]],
[[4, 5, 6], [7, 8]],
[[6, 7, 8], [9, 10]],
[[8, 9, 10], [11, 12]],
[[10, 11, 12], [13, 14]],
]

for i, (train_idx, test_idx) in enumerate(tscv.split(X)):
assert len(train_idx) == max_train_size
assert len(test_idx) == test_size
assert_array_equal(train_idx, expected_splits[i][0])
assert_array_equal(test_idx, expected_splits[i][1])


def test_time_series_rolling_windows_with_gap():
"""Check the behaviour of the `TimeSeriesSplit` with `n_splits='walk_forward'`
with some gap between the train and test sets."""
X = np.random.randn(15, 2)

max_train_size, test_size = 3, 2
tscv = TimeSeriesSplit(
n_splits="walk_forward",
max_train_size=max_train_size,
test_size=test_size,
gap=2,
)
assert tscv.get_n_splits(X) == 5

expected_splits = [
[[0, 1, 2], [5, 6]],
[[2, 3, 4], [7, 8]],
[[4, 5, 6], [9, 10]],
[[6, 7, 8], [11, 12]],
[[8, 9, 10], [13, 14]],
]

for i, (train_idx, test_idx) in enumerate(tscv.split(X)):
assert len(train_idx) == max_train_size
assert len(test_idx) == test_size
assert_array_equal(train_idx, expected_splits[i][0])
assert_array_equal(test_idx, expected_splits[i][1])


def _check_time_series_max_train_size(splits, check_splits, max_train_size):
for (train, test), (check_train, check_test) in zip(splits, check_splits):
assert_array_equal(test, check_test)
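
For readers checking the gap test by hand, the sketch below re-derives the expected folds of `test_time_series_rolling_windows_with_gap` from the parameters instead of hard-coding them; it is a simplified restatement of the split logic on this branch, not the library code itself.

import numpy as np

n_samples, max_train_size, test_size, gap = 15, 3, 2, 2
n_splits = (n_samples - max_train_size - gap) // test_size  # 5 splits expected

indices = np.arange(n_samples)
# The first test window starts at n_samples - n_splits * test_size; each
# subsequent window advances by test_size.
test_starts = range(n_samples - n_splits * test_size, n_samples, test_size)
for test_start in test_starts:
    train_end = test_start - gap
    train = indices[max(train_end - max_train_size, 0):train_end]
    test = indices[test_start:test_start + test_size]
    print(train, test)
# Prints [0 1 2] [5 6], then [2 3 4] [7 8], ... matching expected_splits above.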