Skip to content

Commit

Permalink
FIX Add more validation for size parameters in train_test_split (#12733)
Browse files Browse the repository at this point in the history
* Fixed typo in an example

* Removed un-needed lines from example

* Added tests on the validity of parameters

* Fixed error msg for train_size issue

* Added tests to check the validation of test and train sizes

* Switched to parameterized test

* Fixed typo in error msg

* Swithced to pytest.raises

* Swithced to pytest.raises also when checking msg

* Improved the validity tests and their unit testing
  • Loading branch information
drorata authored and adrinjalali committed Dec 19, 2018
1 parent 2aa9022 commit 440c086
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 25 deletions.
40 changes: 24 additions & 16 deletions sklearn/model_selection/_split.py
Expand Up @@ -1794,23 +1794,25 @@ def _validate_shuffle_split_init(test_size, train_size):

if test_size is not None:
if np.asarray(test_size).dtype.kind == 'f':
if test_size >= 1.:
if test_size >= 1. or test_size <= 0:
raise ValueError(
'test_size=%f should be smaller '
'than 1.0 or be an integer' % test_size)
'test_size=%f should be in the (0, 1) range '
'or be an integer' % test_size)
elif np.asarray(test_size).dtype.kind != 'i':
# int values are checked during split based on the input
raise ValueError("Invalid value for test_size: %r" % test_size)

if train_size is not None:
if np.asarray(train_size).dtype.kind == 'f':
if train_size >= 1.:
raise ValueError("train_size=%f should be smaller "
"than 1.0 or be an integer" % train_size)
if train_size >= 1. or train_size <= 0:
raise ValueError('train_size=%f should be in the (0, 1) range '
'or be an integer' % train_size)
elif (np.asarray(test_size).dtype.kind == 'f' and
(train_size + test_size) > 1.):
(
(train_size + test_size) > 1. or
(train_size + test_size) < 0)):
raise ValueError('The sum of test_size and train_size = %f, '
'should be smaller than 1.0. Reduce '
'should be in the (0, 1) range. Reduce '
'test_size and/or train_size.' %
(train_size + test_size))
elif np.asarray(train_size).dtype.kind != 'i':
Expand All @@ -1824,16 +1826,22 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
size of the data (n_samples)
"""
if (test_size is not None and
np.asarray(test_size).dtype.kind == 'i' and
test_size >= n_samples):
raise ValueError('test_size=%d should be smaller than the number of '
'samples %d' % (test_size, n_samples))
(np.asarray(test_size).dtype.kind == 'i' and
(test_size >= n_samples or test_size <= 0)) or
(np.asarray(test_size).dtype.kind == 'f' and
(test_size <= 0 or test_size >= 1))):
raise ValueError('test_size=%d should be either positive and smaller '
'than the number of samples %d or a float in the '
'(0,1) range' % (test_size, n_samples))

if (train_size is not None and
np.asarray(train_size).dtype.kind == 'i' and
train_size >= n_samples):
raise ValueError("train_size=%d should be smaller than the number of"
" samples %d" % (train_size, n_samples))
(np.asarray(train_size).dtype.kind == 'i' and
(train_size >= n_samples or train_size <= 0)) or
(np.asarray(train_size).dtype.kind == 'f' and
(train_size <= 0 or train_size >= 1))):
raise ValueError('train_size=%d should be either positive and smaller '
'than the number of samples %d or a float in the '
'(0,1) range' % (train_size, n_samples))

if test_size == "default":
test_size = 0.1
Expand Down
51 changes: 42 additions & 9 deletions sklearn/model_selection/tests/test_split.py
Expand Up @@ -1006,27 +1006,60 @@ def test_repeated_stratified_kfold_determinstic_split():


def test_train_test_split_errors():
assert_raises(ValueError, train_test_split)
pytest.raises(ValueError, train_test_split)
with warnings.catch_warnings():
# JvR: Currently, a future warning is raised if test_size is not
# given. As that is the point of this test, ignore the future warning
warnings.filterwarnings("ignore", category=FutureWarning)
assert_raises(ValueError, train_test_split, range(3), train_size=1.1)
pytest.raises(ValueError, train_test_split, range(3), train_size=1.1)

assert_raises(ValueError, train_test_split, range(3), test_size=0.6,
pytest.raises(ValueError, train_test_split, range(3), test_size=0.6,
train_size=0.6)
assert_raises(ValueError, train_test_split, range(3),
pytest.raises(ValueError, train_test_split, range(3),
test_size=np.float32(0.6), train_size=np.float32(0.6))
assert_raises(ValueError, train_test_split, range(3),
pytest.raises(ValueError, train_test_split, range(3),
test_size="wrong_type")
assert_raises(ValueError, train_test_split, range(3), test_size=2,
pytest.raises(ValueError, train_test_split, range(3), test_size=2,
train_size=4)
assert_raises(TypeError, train_test_split, range(3),
pytest.raises(TypeError, train_test_split, range(3),
some_argument=1.1)
assert_raises(ValueError, train_test_split, range(3), range(42))
assert_raises(ValueError, train_test_split, range(10),
pytest.raises(ValueError, train_test_split, range(3), range(42))
pytest.raises(ValueError, train_test_split, range(10),
shuffle=False, stratify=True)

with pytest.raises(ValueError,
match=r'train_size=11 should be either positive and '
r'smaller than the number of samples 10 or a '
r'float in the \(0,1\) range'):
train_test_split(range(10), train_size=11, test_size=1)


@pytest.mark.parametrize("train_size,test_size", [
(1.2, 0.8),
(1., 0.8),
(0.0, 0.8),
(-.2, 0.8),
(0.8, 1.2),
(0.8, 1.),
(0.8, 0.),
(0.8, -.2)])
def test_train_test_split_invalid_sizes1(train_size, test_size):
with pytest.raises(ValueError, match=r'should be in the \(0, 1\) range'):
train_test_split(range(10), train_size=train_size, test_size=test_size)


@pytest.mark.parametrize("train_size,test_size", [
(-10, 0.8),
(0, 0.8),
(11, 0.8),
(0.8, -10),
(0.8, 0),
(0.8, 11)])
def test_train_test_split_invalid_sizes2(train_size, test_size):
with pytest.raises(ValueError,
match=r'should be either positive and smaller'):
train_test_split(range(10), train_size=train_size, test_size=test_size)


def test_train_test_split():
X = np.arange(100).reshape((10, 10))
Expand Down

0 comments on commit 440c086

Please sign in to comment.