From 979e80d13b0f17f5275185d7c7e3ae6bb203ac66 Mon Sep 17 00:00:00 2001 From: Aman Dalmia Date: Fri, 17 Feb 2017 09:38:57 +0530 Subject: [PATCH] TST: tests updated and default value changed to None --- doc/modules/cross_validation.rst | 2 +- sklearn/model_selection/_split.py | 6 +++--- sklearn/model_selection/tests/test_split.py | 20 +++++++++++++++++++- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index fe21a5e4ff720..4b2a5998bb583 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -603,7 +603,7 @@ Example of 3-split time series cross-validation on a dataset with 6 samples:: >>> y = np.array([1, 2, 3, 4, 5, 6]) >>> tscv = TimeSeriesSplit(n_splits=3) >>> print(tscv) # doctest: +NORMALIZE_WHITESPACE - TimeSeriesSplit(max_train_size=0, n_splits=3) + TimeSeriesSplit(max_train_size=None, n_splits=3) >>> for train, test in tscv.split(X): ... print("%s %s" % (train, test)) [0 1 2] [3] diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 969dc2fff2fe9..cfa8c55b60e59 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -674,7 +674,7 @@ class TimeSeriesSplit(_BaseKFold): >>> y = np.array([1, 2, 3, 4]) >>> tscv = TimeSeriesSplit(n_splits=3) >>> print(tscv) # doctest: +NORMALIZE_WHITESPACE - TimeSeriesSplit(max_train_size=0, n_splits=3) + TimeSeriesSplit(max_train_size=None, n_splits=3) >>> for train_index, test_index in tscv.split(X): ... print("TRAIN:", train_index, "TEST:", test_index) ... X_train, X_test = X[train_index], X[test_index] @@ -690,7 +690,7 @@ class TimeSeriesSplit(_BaseKFold): with a test set of size ``n_samples//(n_splits + 1)``, where ``n_samples`` is the number of samples. """ - def __init__(self, n_splits=3, max_train_size=0): + def __init__(self, n_splits=3, max_train_size=None): super(TimeSeriesSplit, self).__init__(n_splits, shuffle=False, random_state=None) @@ -733,7 +733,7 @@ def split(self, X, y=None, groups=None): test_starts = range(test_size + n_samples % n_folds, n_samples, test_size) for test_start in test_starts: - if self.max_train_size > 0 and self.max_train_size < test_start: + if self.max_train_size and self.max_train_size < test_start: yield (indices[test_start - self.max_train_size:test_start], indices[test_start:test_start + test_size]) else: diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index f36824a6acd0b..6790372de8bf0 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -1187,7 +1187,11 @@ def test_time_series_max_train_size(): assert_array_equal(train, [1, 2, 3]) assert_array_equal(test, [4]) - # Test for the case where the first split is less than the max_train_size + train, test = next(splits) + assert_array_equal(train, [2, 3, 4]) + assert_array_equal(test, [5]) + + # Test for the case where the size of a fold is greater than max_train_size splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X) train, test = next(splits) assert_array_equal(train, [1, 2]) @@ -1197,6 +1201,20 @@ def test_time_series_max_train_size(): assert_array_equal(train, [2, 3]) assert_array_equal(test, [4]) + # Test for the case where the size of each fold is less than max_train_size + splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X) + train, test = next(splits) + assert_array_equal(train, [0, 1, 2]) + assert_array_equal(test, [3]) + + train, test = next(splits) + assert_array_equal(train, [0, 1, 2, 3]) + assert_array_equal(test, [4]) + + train, test = next(splits) + assert_array_equal(train, [0, 1, 2, 3, 4]) + assert_array_equal(test, [5]) + def test_nested_cv(): # Test if nested cross validation works with different combinations of cv