Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
TST: tests updated and default value changed to None
  • Loading branch information
dalmia committed Feb 17, 2017
1 parent a1815b5 commit 979e80d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion doc/modules/cross_validation.rst
Expand Up @@ -603,7 +603,7 @@ Example of 3-split time series cross-validation on a dataset with 6 samples::
>>> y = np.array([1, 2, 3, 4, 5, 6])
>>> tscv = TimeSeriesSplit(n_splits=3)
>>> print(tscv) # doctest: +NORMALIZE_WHITESPACE
- TimeSeriesSplit(max_train_size=0, n_splits=3)
+ TimeSeriesSplit(max_train_size=None, n_splits=3)
>>> for train, test in tscv.split(X):
... print("%s %s" % (train, test))
[0 1 2] [3]
Expand Down
6 changes: 3 additions & 3 deletions sklearn/model_selection/_split.py
Expand Up @@ -674,7 +674,7 @@ class TimeSeriesSplit(_BaseKFold):
>>> y = np.array([1, 2, 3, 4])
>>> tscv = TimeSeriesSplit(n_splits=3)
>>> print(tscv) # doctest: +NORMALIZE_WHITESPACE
- TimeSeriesSplit(max_train_size=0, n_splits=3)
+ TimeSeriesSplit(max_train_size=None, n_splits=3)
>>> for train_index, test_index in tscv.split(X):
... print("TRAIN:", train_index, "TEST:", test_index)
... X_train, X_test = X[train_index], X[test_index]
Expand All @@ -690,7 +690,7 @@ class TimeSeriesSplit(_BaseKFold):
with a test set of size ``n_samples//(n_splits + 1)``,
where ``n_samples`` is the number of samples.
"""
- def __init__(self, n_splits=3, max_train_size=0):
+ def __init__(self, n_splits=3, max_train_size=None):
super(TimeSeriesSplit, self).__init__(n_splits,
shuffle=False,
random_state=None)
Expand Down Expand Up @@ -733,7 +733,7 @@ def split(self, X, y=None, groups=None):
test_starts = range(test_size + n_samples % n_folds,
n_samples, test_size)
for test_start in test_starts:
- if self.max_train_size > 0 and self.max_train_size < test_start:
+ if self.max_train_size and self.max_train_size < test_start:
yield (indices[test_start - self.max_train_size:test_start],
indices[test_start:test_start + test_size])
else:
Expand Down
20 changes: 19 additions & 1 deletion sklearn/model_selection/tests/test_split.py
Expand Up @@ -1187,7 +1187,11 @@ def test_time_series_max_train_size():
assert_array_equal(train, [1, 2, 3])
assert_array_equal(test, [4])

- # Test for the case where the first split is less than the max_train_size
+ train, test = next(splits)
+ assert_array_equal(train, [2, 3, 4])
+ assert_array_equal(test, [5])
+
+ # Test for the case where the size of a fold is greater than max_train_size
splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
train, test = next(splits)
assert_array_equal(train, [1, 2])
Expand All @@ -1197,6 +1201,20 @@ def test_time_series_max_train_size():
assert_array_equal(train, [2, 3])
assert_array_equal(test, [4])

# Test for the case where the size of each fold is less than max_train_size
splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
train, test = next(splits)
assert_array_equal(train, [0, 1, 2])
assert_array_equal(test, [3])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3])
assert_array_equal(test, [4])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3, 4])
assert_array_equal(test, [5])


def test_nested_cv():
# Test if nested cross validation works with different combinations of cv
Expand Down

0 comments on commit 979e80d

Please sign in to comment.