Skip to content

Commit

Permalink
TST n_splits and split wrapping of _CVIterableWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
raghavrv committed Aug 13, 2015
1 parent 7f40420 commit cb082ff
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions sklearn/model_selection/tests/test_split.py
Expand Up @@ -35,6 +35,7 @@

from sklearn.model_selection._split import _safe_split
from sklearn.model_selection._split import _validate_shuffle_split
from sklearn.model_selection._split import _CVIterableWrapper

from sklearn.datasets import load_digits
from sklearn.datasets import load_iris
Expand Down Expand Up @@ -759,3 +760,19 @@ def test_check_cv_return_types():
cv2 = check_cv(OldSKF(y_multiclass, n_folds=3))
np.testing.assert_equal(list(cv1.split(X, y_multiclass)),
list(cv2.split()))


def test_cv_iterable_wrapper():
y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])

with warnings.catch_warnings(record=True):
from sklearn.cross_validation import StratifiedKFold as OldSKF

cv = OldSKF(y_multiclass, n_folds=3)
wrapped_old_skf = _CVIterableWrapper(cv)

# Check if split works correctly
np.testing.assert_equal(list(cv), list(wrapped_old_skf.split()))

# Check if n_splits works correctly
assert_equal(len(cv), wrapped_old_skf.n_splits())

0 comments on commit cb082ff

Please sign in to comment.