-
-
Notifications
You must be signed in to change notification settings - Fork 25.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH Add option n_splits='walk_forward' in TimeSeriesSplit
#23780
base: main
Are you sure you want to change the base?
Changes from 5 commits
a01e12d
38c78cf
3edd8db
9764f6d
29ff1dd
4512e66
d1ec219
b5cf2cf
ee3315a
52642e6
06c5b44
a63b61d
97f380d
0cfbb7d
47fbc01
481cf56
9a27b48
7876b0a
d907960
cc5e970
bad0678
8df4f5b
8efb3aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -274,13 +274,38 @@ class _BaseKFold(BaseCrossValidator, metaclass=ABCMeta): | |
"""Base class for KFold, GroupKFold, and StratifiedKFold""" | ||
|
||
@abstractmethod | ||
def __init__(self, n_splits, *, shuffle, random_state): | ||
def __init__( | ||
self, | ||
n_splits, | ||
*, | ||
shuffle, | ||
random_state, | ||
x_shape=None, | ||
max_train_size=None, | ||
test_size=None, | ||
): | ||
if isinstance(n_splits, int): | ||
# self.n_splits = n_splits | ||
n_splits = int(n_splits) | ||
|
||
if not isinstance(n_splits, numbers.Integral): | ||
raise ValueError( | ||
"The number of folds must be of Integral type. " | ||
"%s of type %s was passed." % (n_splits, type(n_splits)) | ||
) | ||
n_splits = int(n_splits) | ||
if ( | ||
x_shape | ||
and max_train_size | ||
and test_size | ||
and isinstance(n_splits, str) | ||
and n_splits == "walk_forward" | ||
): | ||
n_splits = self.find_walk_forward_n_splits_value( | ||
x_shape, max_train_size, test_size | ||
) | ||
else: | ||
raise ValueError( | ||
"The number of folds must be of Integral type. " | ||
"%s of type %s was passed." % (n_splits, type(n_splits)) | ||
) | ||
|
||
# n_splits = int(n_splits) | ||
|
||
if n_splits <= 1: | ||
raise ValueError( | ||
|
@@ -1035,8 +1060,20 @@ class TimeSeriesSplit(_BaseKFold): | |
where ``n_samples`` is the number of samples. | ||
""" | ||
|
||
def __init__(self, n_splits=5, *, max_train_size=None, test_size=None, gap=0): | ||
super().__init__(n_splits, shuffle=False, random_state=None) | ||
def __init__( | ||
self, n_splits=5, x_shape=None, *, max_train_size=None, test_size=None, gap=0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not need to have Having There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've removed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ShehanAT Consider updating the pull request description as well. |
||
): | ||
if x_shape and max_train_size and test_size: | ||
super().__init__( | ||
n_splits, | ||
shuffle=False, | ||
random_state=None, | ||
x_shape=x_shape, | ||
max_train_size=max_train_size, | ||
test_size=test_size, | ||
) | ||
else: | ||
super().__init__(n_splits, shuffle=False, random_state=None) | ||
self.max_train_size = max_train_size | ||
self.test_size = test_size | ||
self.gap = gap | ||
|
@@ -1066,13 +1103,13 @@ def split(self, X, y=None, groups=None): | |
""" | ||
X, y, groups = indexable(X, y, groups) | ||
n_samples = _num_samples(X) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for this change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed |
||
n_splits = self.n_splits | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that this is here that we can do:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, added the line: |
||
n_folds = n_splits + 1 | ||
gap = self.gap | ||
test_size = ( | ||
self.test_size if self.test_size is not None else n_samples // n_folds | ||
) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for this change |
||
# Make sure we have enough samples for the given split parameters | ||
if n_folds > n_samples: | ||
raise ValueError( | ||
|
@@ -1101,6 +1138,105 @@ def split(self, X, y=None, groups=None): | |
indices[test_start : test_start + test_size], | ||
) | ||
|
||
def find_walk_forward_n_splits_value(self, x_value, max_train_size, test_size): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we replace this by def get_n_splits(self, X=None, y=None, groups=None) Since we have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea |
||
""" | ||
Time Series `n_splits` value calculator | ||
Calculates the `n_splits` variable's value so that the "walk_forward" | ||
functionality of the split time series data samples is possible. | ||
The `n_splits` value is used to calculate the number of splits | ||
when creating train/test indices. | ||
While normally the `n_splits` value would be provided by the user, | ||
entering `walk_forward` as the value for the `n_splits` parameter in the | ||
TimeSeriesSplit constructor would call this method in order to | ||
determine an appropriate `n_splits` value for the "walk_forward" feature. | ||
The "walk_forward" feature is defined by any train/test indices that | ||
have their first element starting at `0` and | ||
progress in increasing values throughout separate train/test indices | ||
until the ending value equals the `x_value` value minus 1. | ||
|
||
Read more in the :ref:`User Guide <time_series_split>`. | ||
Read more about the "walk_forward" feature | ||
in this GitHub Issue: https://github.com/scikit-learn/scikit-learn/issues/22523 | ||
|
||
.. versionadded:: 0.18 | ||
|
||
Parameters | ||
---------- | ||
x_value : int | ||
Number of samples, used as the stop value of the ``np.arange`` array. | ||
The "walk forward" train/test indices go up to | ||
``x_value - 1``. | ||
|
||
.. versionchanged:: 0.22 | ||
|
||
max_train_size : int | ||
Maximum size for a single training set. | ||
|
||
.. versionadded:: 0.24 | ||
|
||
test_size : int | ||
Used to limit the size of the test set. Defaults to | ||
``n_samples // (n_splits + 1)``, which is the maximum allowed value | ||
with ``gap=0``. | ||
|
||
.. versionadded:: 0.24 | ||
|
||
Examples | ||
-------------- | ||
>>> x = np.arange(15) | ||
>>> cv = TimeSeriesSplit(n_splits="walk_forward", x_shape=x.shape[0], | ||
max_train_size=10 ,test_size=2) | ||
>>> for train_index, test_index in cv.split(x): | ||
... print("TRAIN: ", train_index, "TEST: ", test_index) | ||
... | ||
TRAIN: [0 1 2 3 4 5 6 7 8] TEST: [ 9 10] | ||
TRAIN: [ 1 2 3 4 5 6 7 8 9 10] TEST: [11 12] | ||
TRAIN: [ 3 4 5 6 7 8 9 10 11 12] TEST: [13 14] | ||
|
||
>>> x = np.arange(15) | ||
>>> cv = TimeSeriesSplit(n_splits="walk_forward", x_shape=x.shape[0], | ||
max_train_size=3, test_size=1) | ||
>>> for train_index, test_index in cv.split(x): | ||
... print("TRAIN: ", train_index, "TEST: ", test_index) | ||
... | ||
TRAIN: [0 1 2] TEST: [3] | ||
TRAIN: [1 2 3] TEST: [4] | ||
TRAIN: [2 3 4] TEST: [5] | ||
TRAIN: [3 4 5] TEST: [6] | ||
TRAIN: [4 5 6] TEST: [7] | ||
TRAIN: [5 6 7] TEST: [8] | ||
TRAIN: [6 7 8] TEST: [9] | ||
TRAIN: [7 8 9] TEST: [10] | ||
TRAIN: [ 8 9 10] TEST: [11] | ||
TRAIN: [ 9 10 11] TEST: [12] | ||
TRAIN: [10 11 12] TEST: [13] | ||
TRAIN: [11 12 13] TEST: [14] | ||
""" | ||
x = np.arange(x_value) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that we only need to have:
def get_n_splits(self, X=None, y=None, groups=None):
    if isinstance(self.n_splits, str):
        return X.shape[0] - (self.max_train_size + self.test_size) + 1
    return self.n_splits
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, this is a lot simpler and more readable than the previous method |
||
time_splits_storage = [ | ||
[] for i in range(x_value) | ||
] # make the two array so that x_value defines its max length | ||
for i in range(2, x_value, 1): | ||
try: | ||
cv = TimeSeriesSplit( | ||
n_splits=i, max_train_size=max_train_size, test_size=test_size | ||
) | ||
for train_index, test_index in cv.split(x): | ||
time_splits_storage[i].append([train_index, test_index]) | ||
except ValueError: | ||
pass | ||
n_splits_arrays_first_element_zero = [] | ||
for i in range(len(time_splits_storage)): | ||
for j in range(len(time_splits_storage[i])): | ||
if time_splits_storage[i][0][0][0] == 0: | ||
n_splits_arrays_first_element_zero.append(i) | ||
break | ||
|
||
if len(n_splits_arrays_first_element_zero) > 0: | ||
return n_splits_arrays_first_element_zero[0] | ||
else: | ||
return False | ||
|
||
|
||
class LeaveOneGroupOut(BaseCrossValidator): | ||
"""Leave One Group Out cross-validator | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking about it, I think that we can keep the base class as is and make the base class accept a `str` without raising a warning. Then, we can just specialize `get_n_splits` for `TimeSeriesSplit`, which should not return `self.n_splits` directly but instead compute `n_splits` if a string is provided. Therefore, we only specialize the `TimeSeriesSplit` class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I've replaced the previous method with this `get_n_splits()` method now.