Browse files

TEST: Tested ShuffleSplit with different types of test_size

  • Loading branch information...
1 parent 3c64e1b commit 324abdd3147bf6e379bfd5fa2a61a887ee5ba17b @davidmarek davidmarek committed with GaelVaroquaux Apr 5, 2012
Showing with 14 additions and 0 deletions.
  1. +14 −0 sklearn/tests/test_cross_validation.py
View
14 sklearn/tests/test_cross_validation.py
@@ -92,6 +92,20 @@ def test_shuffle_kfold():
assert_array_equal(all_folds, ind)
+def test_shuffle_split():
+ ss1 = cval.ShuffleSplit(10, test_size=0.2, random_state=0)
+ ss2 = cval.ShuffleSplit(10, test_size=2, random_state=0)
+ ss3 = cval.ShuffleSplit(10, test_size=np.int32(2), random_state=0)
+ ss4 = cval.ShuffleSplit(10, test_size=long(2), random_state=0)
+ for t1, t2, t3, t4 in zip(ss1, ss2, ss3, ss4):
+ assert_array_equal(t1[0], t2[0])
+ assert_array_equal(t2[0], t3[0])
+ assert_array_equal(t3[0], t4[0])
+ assert_array_equal(t1[1], t2[1])
+ assert_array_equal(t2[1], t3[1])
+ assert_array_equal(t3[1], t4[1])
+
+
def test_stratified_shuffle_split():
y = np.asarray([0, 1, 1, 1, 2, 2, 2])
# Check that error is raised if there is a class with only one sample

0 comments on commit 324abdd

Please sign in to comment.