ENH Add option n_splits='walk_forward' in TimeSeriesSplit #23780

Open
wants to merge 23 commits into base: main

Commits (23)
a01e12d
provided new link for `SVD based initialization` link on line 959, re…
ShehanAT Jun 19, 2022
38c78cf
added find_walk_forward_n_splits_value() to sklearn/model_selection/_…
ShehanAT Jun 28, 2022
3edd8db
fixed merge conflict in doc/modules/decomposition.rst
ShehanAT Jun 28, 2022
9764f6d
refactored find_walk_forward_n_splits_value() in sklearn/model_select…
ShehanAT Jun 28, 2022
29ff1dd
fixed grammar error in docstring in sklearn/model_selection/_split.py
ShehanAT Jun 28, 2022
4512e66
made changes suggested by @glemaitre
ShehanAT Jun 28, 2022
d1ec219
refactored sklearn/model_selection/_split.py according to @glemaitre
ShehanAT Jun 28, 2022
b5cf2cf
added two unit tests regard 'walk_forward' TimeSeriesSplit feature
ShehanAT Jun 30, 2022
ee3315a
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
ShehanAT Jun 30, 2022
52642e6
added entry to doc/whats_new/v1.2.rst
ShehanAT Jun 30, 2022
06c5b44
updated docstring to mention 'walk_forward' support for model_selecti…
ShehanAT Jun 30, 2022
a63b61d
Merge remote-tracking branch 'origin/main' into pr/ShehanAT/23780-1
glemaitre Nov 3, 2022
97f380d
DOC move the entry in changelog
glemaitre Nov 3, 2022
0cfbb7d
DOC update docstring for walk_forward
glemaitre Nov 3, 2022
47fbc01
DOC add user guide documentation
glemaitre Nov 3, 2022
481cf56
TST reformat tests
glemaitre Nov 3, 2022
9a27b48
FIX add dosctring and fix n_split
glemaitre Nov 3, 2022
7876b0a
FIX make sure to take gap into account
glemaitre Nov 3, 2022
d907960
DOC add gap into examples in user guide
glemaitre Nov 3, 2022
cc5e970
DOC fix docstring
glemaitre Nov 10, 2022
bad0678
Merge branch 'main' into walk-forward-time-series-split-ShehanAT
glemaitre May 31, 2023
8df4f5b
DOC Update to 1.3
thomasjpfan May 31, 2023
8efb3aa
Merge branch 'main' into walk-forward-time-series-split-ShehanAT
glemaitre Dec 7, 2023
154 changes: 145 additions & 9 deletions sklearn/model_selection/_split.py
@@ -274,13 +274,38 @@ class _BaseKFold(BaseCrossValidator, metaclass=ABCMeta):
"""Base class for KFold, GroupKFold, and StratifiedKFold"""

    @abstractmethod
    def __init__(self, n_splits, *, shuffle, random_state):
    def __init__(
        self,
        n_splits,
        *,
        shuffle,
        random_state,
        x_shape=None,
        max_train_size=None,
        test_size=None,
    ):
        if isinstance(n_splits, int):
Member:

On reflection, I think we can keep the base class as is and simply let it accept a str without raising.

Then we can specialize get_n_splits for TimeSeriesSplit: instead of returning self.n_splits directly, it should compute n_splits when a string is provided.

That way, we only specialize the TimeSeriesSplit class.
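To make the suggestion above concrete, here is a minimal sketch of the pattern: the base class stores a string `n_splits` untouched, and only the subclass resolves it once the data is known. The class names `BaseSplitter` and `RollingSplitter` are hypothetical stand-ins, not scikit-learn classes, and the expression used to resolve the string is the one proposed later in this review, not necessarily what was merged.

```python
import numbers


class BaseSplitter:
    """Hypothetical stand-in for the k-fold base class (not scikit-learn code)."""

    def __init__(self, n_splits):
        # Integers are validated eagerly; a string is stored untouched so that
        # a subclass can resolve it later, once the data is available.
        if isinstance(n_splits, numbers.Integral):
            if n_splits <= 1:
                raise ValueError(f"n_splits must be at least 2, got {n_splits}.")
        elif not isinstance(n_splits, str):
            raise ValueError(
                f"n_splits must be an integer or a string, got {type(n_splits)}."
            )
        self.n_splits = n_splits

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits


class RollingSplitter(BaseSplitter):
    """Hypothetical subclass: the only place that understands a string n_splits."""

    def __init__(self, n_splits=5, *, max_train_size=None, test_size=None):
        super().__init__(n_splits)
        self.max_train_size = max_train_size
        self.test_size = test_size

    def get_n_splits(self, X=None, y=None, groups=None):
        if isinstance(self.n_splits, str):
            # Expression quoted from a later comment in this review; it needs X.
            return len(X) - (self.max_train_size + self.test_size) + 1
        return self.n_splits
```

With this layout the constructor never needs the data shape: `RollingSplitter("walk_forward", max_train_size=3, test_size=1).get_n_splits(range(15))` resolves the count only when `X` is passed in.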

Contributor Author (@ShehanAT, Jun 28, 2022):

OK, I've replaced the previous method with this get_n_splits() method now.

# self.n_splits = n_splits
n_splits = int(n_splits)

if not isinstance(n_splits, numbers.Integral):
raise ValueError(
"The number of folds must be of Integral type. "
"%s of type %s was passed." % (n_splits, type(n_splits))
)
n_splits = int(n_splits)
if (
x_shape
and max_train_size
and test_size
and isinstance(n_splits, str)
and n_splits == "walk_forward"
):
n_splits = self.find_walk_forward_n_splits_value(
x_shape, max_train_size, test_size
)
else:
raise ValueError(
"The number of folds must be of Integral type. "
"%s of type %s was passed." % (n_splits, type(n_splits))
)

# n_splits = int(n_splits)

if n_splits <= 1:
raise ValueError(
@@ -1035,8 +1060,20 @@ class TimeSeriesSplit(_BaseKFold):
where ``n_samples`` is the number of samples.
"""

    def __init__(self, n_splits=5, *, max_train_size=None, test_size=None, gap=0):
        super().__init__(n_splits, shuffle=False, random_state=None)
    def __init__(
        self, n_splits=5, x_shape=None, *, max_train_size=None, test_size=None, gap=0
Member:

We should not need to have x_shape at the initialization.

Having X at the split call should be enough to get this information.

Contributor Author:

I've removed x_shape now

Contributor:

@ShehanAT Consider updating the pull request description as well.

    ):
        if x_shape and max_train_size and test_size:
            super().__init__(
                n_splits,
                shuffle=False,
                random_state=None,
                x_shape=x_shape,
                max_train_size=max_train_size,
                test_size=test_size,
            )
        else:
            super().__init__(n_splits, shuffle=False, random_state=None)
        self.max_train_size = max_train_size
        self.test_size = test_size
        self.gap = gap
@@ -1066,13 +1103,13 @@ def split(self, X, y=None, groups=None):
"""
X, y, groups = indexable(X, y, groups)
n_samples = _num_samples(X)

Member:

no need for this change

Contributor Author (@ShehanAT, Jun 28, 2022):

Removed

n_splits = self.n_splits
Member:

I think this is where we can do:

n_splits = self.get_n_splits(...)

Contributor Author:

Yep, I've now added the line: n_splits = self.get_n_splits(X, y, groups)
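For readers following the thread, this is roughly how a resolved `n_splits` feeds the window arithmetic in `split`. The function below is a pure-Python sketch, not scikit-learn code: `rolling_split` is a hypothetical name, it takes `n_samples` directly instead of `X`, and the `test_starts` range and the "too many splits" check are written from my understanding of the existing `TimeSeriesSplit.split`, so treat them as assumptions.

```python
def rolling_split(n_samples, n_splits, *, max_train_size=None, test_size=None, gap=0):
    """Yield (train, test) index lists; a sketch of the index arithmetic only."""
    n_folds = n_splits + 1
    if n_folds > n_samples:
        raise ValueError("Cannot have more folds than samples.")
    if test_size is None:
        # Default mirrors the diff: n_samples // n_folds.
        test_size = n_samples // n_folds
    if n_samples - gap - test_size * n_splits <= 0:
        raise ValueError("Too many splits for the number of samples.")

    indices = list(range(n_samples))
    # Test windows are laid out back to back, ending at the last sample.
    test_starts = range(n_samples - n_splits * test_size, n_samples, test_size)
    for test_start in test_starts:
        train_end = test_start - gap
        if max_train_size and max_train_size < train_end:
            # Rolling window: keep only the most recent max_train_size samples.
            train = indices[train_end - max_train_size : train_end]
        else:
            # Expanding window: everything before the test block (minus the gap).
            train = indices[:train_end]
        yield train, indices[test_start : test_start + test_size]


for train, test in rolling_split(15, 3, max_train_size=10, test_size=2):
    print("TRAIN:", train, "TEST:", test)
```

Under these assumptions, `n_splits=3` with `max_train_size=10` and `test_size=2` on 15 samples yields test blocks `[9, 10]`, `[11, 12]`, `[13, 14]`, with the first training window starting at index 0.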

        n_folds = n_splits + 1
        gap = self.gap
        test_size = (
            self.test_size if self.test_size is not None else n_samples // n_folds
        )

Member:

no need for this change

# Make sure we have enough samples for the given split parameters
if n_folds > n_samples:
raise ValueError(
@@ -1101,6 +1138,105 @@ def split(self, X, y=None, groups=None):
indices[test_start : test_start + test_size],
)

def find_walk_forward_n_splits_value(self, x_value, max_train_size, test_size):
Member:

If we replace this with get_n_splits, we will have the following signature:

def get_n_splits(self, X=None, y=None, groups=None)

Since we have X and self, we will have all the information needed to compute the number of splits required for the rolling windows.

Contributor Author:

Good idea

"""
Time Series `n_splits` value calculator
Calculates the `n_splits` variable's value so that the "walk_forward"
functionality of the split time series data samples is possible.
The `n_splits` value is used to calculate the number of splits
when creating train/test indices.
While normally the `n_splits` value would be provided by the user,
entering `walk_forward` as the value for the `n_splits` parameter in the
TimeSeriesSplit constructor would call this method in order to
determine an appropriate `n_splits` value for the "walk_forward" feature.
The "walk_forward" feature is defined by train/test indices whose
first element starts at `0` and which progress through increasing
values across separate train/test splits until the last index equals
`x_value` minus 1.

Read more in the :ref:`User Guide <time_series_split>`.
Read more about the "walk_forward" feature
in this GitHub Issue: https://github.com/scikit-learn/scikit-learn/issues/22523

.. versionadded:: 0.18

Parameters
----------
x_value : int
Number of samples, i.e. the argument passed to np.arange.
The "walk forward" train/test indices run up to `x_value - 1`.

.. versionchanged:: 0.22

max_train_size : int
Maximum size for a single training set.

.. versionadded:: 0.24

test_size : int
Used to limit the size of the test set. Defaults to
``n_samples // (n_splits + 1)``, which is the maximum allowed value
with ``gap=0``.

.. versionadded:: 0.24

Examples
--------
>>> x = np.arange(15)
>>> cv = TimeSeriesSplit(n_splits="walk_forward", x_shape=x.shape[0],
max_train_size=10 ,test_size=2)
>>> for train_index, test_index in cv.split(x):
... print("TRAIN: ", train_index, "TEST: ", test_index)
...
TRAIN: [0 1 2 3 4 5 6 7 8] TEST: [ 9 10]
TRAIN: [ 1 2 3 4 5 6 7 8 9 10] TEST: [11 12]
TRAIN: [ 3 4 5 6 7 8 9 10 11 12] TEST: [13 14]

>>> x = np.arange(15)
>>> cv = TimeSeriesSplit(n_splits="walk_forward", x_shape=x.shape[0],
max_train_size=3, test_size=1)
>>> for train_index, test_index in cv.split(x):
... print("TRAIN: ", train_index, "TEST: ", test_index)
...
TRAIN: [0 1 2] TEST: [3]
TRAIN: [1 2 3] TEST: [4]
TRAIN: [2 3 4] TEST: [5]
TRAIN: [3 4 5] TEST: [6]
TRAIN: [4 5 6] TEST: [7]
TRAIN: [5 6 7] TEST: [8]
TRAIN: [6 7 8] TEST: [9]
TRAIN: [7 8 9] TEST: [10]
TRAIN: [ 8 9 10] TEST: [11]
TRAIN: [ 9 10 11] TEST: [12]
TRAIN: [10 11 12] TEST: [13]
TRAIN: [11 12 13] TEST: [14]
"""
x = np.arange(x_value)
Member:

I think that we only need to have:

def get_n_splits(self, X=None, y=None, groups=None):
    if isinstance(self.n_splits, str):
        return X.shape[0] - (self.max_train_size + self.test_size) + 1
    return self.n_splits

Contributor Author:

Yep, this is a lot simpler and more readable than the previous method.
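As a quick sanity check, the closed-form expression can be compared with what the brute-force search in `find_walk_forward_n_splits_value` looks for, namely the smallest `n_splits >= 2` whose first training window starts at index 0. The sketch below is pure Python with hypothetical helper names; it assumes `gap=0` and that test windows advance by `test_size` samples per split, which is my reading of `split`, not something shown in this hunk.

```python
def first_train_index(n_samples, n_splits, max_train_size, test_size):
    """Start index of the first training window (gap=0), or None if invalid."""
    test_start = n_samples - n_splits * test_size
    if test_start <= 0:
        return None  # not enough samples for this many splits
    train_end = test_start
    # A rolling window starts max_train_size samples before train_end,
    # clipped at 0 when fewer samples are available.
    return max(train_end - max_train_size, 0)


def brute_force_n_splits(n_samples, max_train_size, test_size):
    """Smallest n_splits >= 2 whose first training window starts at index 0."""
    for n_splits in range(2, n_samples):
        if first_train_index(n_samples, n_splits, max_train_size, test_size) == 0:
            return n_splits
    return None


def closed_form_n_splits(n_samples, max_train_size, test_size):
    """Expression proposed in the review comment above."""
    return n_samples - (max_train_size + test_size) + 1


for args in [(15, 3, 1), (15, 10, 2)]:
    print(args, brute_force_n_splits(*args), closed_form_n_splits(*args))
```

Under these assumptions both approaches give 12 for `max_train_size=3, test_size=1` on 15 samples, matching the second docstring example. For `max_train_size=10, test_size=2` the search gives 3, which is the number of splits the first docstring example shows, while the closed form gives 4, so the one-liner appears to assume a step of one sample between windows.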

        time_splits_storage = [
            [] for i in range(x_value)
        ]  # make the two array so that x_value defines its max length
        for i in range(2, x_value, 1):
            try:
                cv = TimeSeriesSplit(
                    n_splits=i, max_train_size=max_train_size, test_size=test_size
                )
                for train_index, test_index in cv.split(x):
                    time_splits_storage[i].append([train_index, test_index])
            except ValueError:
                pass
        n_splits_arrays_first_element_zero = []
        for i in range(len(time_splits_storage)):
            for j in range(len(time_splits_storage[i])):
                if time_splits_storage[i][0][0][0] == 0:
                    n_splits_arrays_first_element_zero.append(i)
                    break

        if len(n_splits_arrays_first_element_zero) > 0:
            return n_splits_arrays_first_element_zero[0]
        else:
            return False


class LeaveOneGroupOut(BaseCrossValidator):
"""Leave One Group Out cross-validator