edit train/test_size default behavior
nelson-liu committed Sep 19, 2016
1 parent 49fb295 commit 83a7a05
Showing 4 changed files with 101 additions and 45 deletions.
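In short: ShuffleSplit's test_size default changes from 0.1 to None in both the deprecated cross_validation module and model_selection, and the validation helpers now fall back to train_size=0.9, test_size=0.1 when neither size is given, instead of raising an error for the explicit None/None case. A minimal sketch of the resulting behavior (assuming a scikit-learn checkout with this commit applied):

    import numpy as np
    from sklearn.model_selection import ShuffleSplit

    X = np.arange(10).reshape((5, 2))

    # With this commit, all three splitters draw identical splits for a
    # given seed: the defaults fall back to a 0.9/0.1 split, and a single
    # explicit size is complemented (n_train = n - n_test and vice versa).
    for ss in (ShuffleSplit(random_state=0),
               ShuffleSplit(random_state=0, train_size=0.9),
               ShuffleSplit(random_state=0, test_size=0.1)):
        train_idx, test_idx = next(iter(ss.split(X)))
        print(train_idx, test_idx)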
41 changes: 28 additions & 13 deletions sklearn/cross_validation.py
@@ -230,8 +230,8 @@ def __repr__(self):
         )

     def __len__(self):
-        return int(factorial(self.n) / factorial(self.n - self.p)
-                   / factorial(self.p))
+        return int(factorial(self.n) / factorial(self.n - self.p) /
+                   factorial(self.p))


 class _BaseKFold(with_metaclass(ABCMeta, _PartitionIterator)):
@@ -738,7 +738,7 @@ def __len__(self):
 class BaseShuffleSplit(with_metaclass(ABCMeta)):
     """Base class for ShuffleSplit and StratifiedShuffleSplit"""

-    def __init__(self, n, n_iter=10, test_size=0.1, train_size=None,
+    def __init__(self, n, n_iter=10, test_size=None, train_size=None,
                  random_state=None):
         self.n = n
         self.n_iter = n_iter
@@ -845,9 +845,8 @@ def __len__(self):

 def _validate_shuffle_split(n, test_size, train_size):
     if test_size is None and train_size is None:
-        raise ValueError(
-            'test_size and train_size can not both be None')
-
+        train_size = 0.9
+        test_size = 0.1
     if test_size is not None:
         if np.asarray(test_size).dtype.kind == 'f':
             if test_size >= 1.:
@@ -881,21 +880,37 @@ def _validate_shuffle_split(n, test_size, train_size):
     else:
         raise ValueError("Invalid value for train_size: %r" % train_size)

-    if np.asarray(test_size).dtype.kind == 'f':
-        n_test = ceil(test_size * n)
-    elif np.asarray(test_size).dtype.kind == 'i':
-        n_test = float(test_size)
+    if test_size is None:
+        # only train_size set, so set test_size as
+        # n - n_train
+        if np.asarray(train_size).dtype.kind == 'f':
+            n_train = floor(train_size * n)
+        elif np.asarray(train_size).dtype.kind == 'i':
+            n_train = float(train_size)
+
+        # set n_test to be the complement of n_train
+        n_test = n - n_train

-    if train_size is None:
+    elif train_size is None:
+        # only test_size was set, so set train_size as
+        # n - n_test
+        if np.asarray(test_size).dtype.kind == 'f':
+            n_test = ceil(test_size * n)
+        elif np.asarray(test_size).dtype.kind == 'i':
+            n_test = float(test_size)
+
         n_train = n - n_test
     else:
+        # both train_size and test_size set, so subsample
         if np.asarray(train_size).dtype.kind == 'f':
             n_train = floor(train_size * n)
         else:
             n_train = float(train_size)

-    if test_size is None:
-        n_test = n - n_train
+        if np.asarray(test_size).dtype.kind == 'f':
+            n_test = ceil(test_size * n)
+        else:
+            n_test = float(test_size)

     if n_train + n_test > n:
         raise ValueError('The sum of train_size and test_size = %d, '
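The rewritten helper above boils down to three cases: only train_size given, so test gets the complement; only test_size given, so train gets the complement; both given, so each is resolved independently (and their sum is checked). A condensed, standalone sketch of that resolution logic, using the hypothetical name resolve_sizes (not the library code; the ceil/floor choices match the diff):

    from math import ceil, floor

    def resolve_sizes(n, test_size=None, train_size=None):
        # fallback introduced by this commit: both None -> 0.9/0.1
        if test_size is None and train_size is None:
            train_size, test_size = 0.9, 0.1
        if test_size is None:
            # only train_size set: test gets the remainder
            n_train = floor(train_size * n) if isinstance(train_size, float) else train_size
            n_test = n - n_train
        elif train_size is None:
            # only test_size set: train gets the remainder
            n_test = ceil(test_size * n) if isinstance(test_size, float) else test_size
            n_train = n - n_test
        else:
            # both set: resolve independently; the caller checks the sum
            n_train = floor(train_size * n) if isinstance(train_size, float) else train_size
            n_test = ceil(test_size * n) if isinstance(test_size, float) else test_size
        return int(n_train), int(n_test)

    print(resolve_sizes(10))                  # (9, 1) -- the new default
    print(resolve_sizes(10, train_size=0.9))  # (9, 1)
    print(resolve_sizes(10, test_size=0.1))   # (9, 1)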
54 changes: 36 additions & 18 deletions sklearn/model_selection/_split.py
@@ -896,7 +896,7 @@ def get_n_splits(self, X, y, labels):
 class BaseShuffleSplit(with_metaclass(ABCMeta)):
     """Base class for ShuffleSplit and StratifiedShuffleSplit"""

-    def __init__(self, n_splits=10, test_size=0.1, train_size=None,
+    def __init__(self, n_splits=10, test_size=None, train_size=None,
                  random_state=None):
         _validate_shuffle_split_init(test_size, train_size)
         self.n_splits = n_splits
@@ -1251,9 +1251,6 @@ def _validate_shuffle_split_init(test_size, train_size):
     NOTE This does not take into account the number of samples which is known
     only at split
     """
-    if test_size is None and train_size is None:
-        raise ValueError('test_size and train_size can not both be None')
-
     if test_size is not None:
         if np.asarray(test_size).dtype.kind == 'f':
             if test_size >= 1.:
@@ -1285,30 +1282,51 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
     Validation helper to check if the test/test sizes are meaningful wrt to the
     size of the data (n_samples)
     """
-    if (test_size is not None and np.asarray(test_size).dtype.kind == 'i'
-            and test_size >= n_samples):
+    if test_size is None and train_size is None:
+        train_size = 0.9
+        test_size = 0.1
+
+    if (test_size is not None and np.asarray(test_size).dtype.kind == 'i' and
+            test_size >= n_samples):
         raise ValueError('test_size=%d should be smaller than the number of '
                          'samples %d' % (test_size, n_samples))

-    if (train_size is not None and np.asarray(train_size).dtype.kind == 'i'
-            and train_size >= n_samples):
+    if (train_size is not None and np.asarray(train_size).dtype.kind == 'i' and
+            train_size >= n_samples):
         raise ValueError("train_size=%d should be smaller than the number of"
                          " samples %d" % (train_size, n_samples))

-    if np.asarray(test_size).dtype.kind == 'f':
-        n_test = ceil(test_size * n_samples)
-    elif np.asarray(test_size).dtype.kind == 'i':
-        n_test = float(test_size)
+    if test_size is None:
+        # only train_size set, so set test_size as
+        # n - n_train
+        if np.asarray(train_size).dtype.kind == 'f':
+            n_train = floor(train_size * n_samples)
+        elif np.asarray(train_size).dtype.kind == 'i':
+            n_train = float(train_size)
+
+        # set n_test to be the complement of n_train
+        n_test = n_samples - n_train

-    if train_size is None:
+    elif train_size is None:
+        # only test_size was set, so set train_size as
+        # n - n_test
+        if np.asarray(test_size).dtype.kind == 'f':
+            n_test = ceil(test_size * n_samples)
+        elif np.asarray(test_size).dtype.kind == 'i':
+            n_test = float(test_size)
+
         n_train = n_samples - n_test
-    elif np.asarray(train_size).dtype.kind == 'f':
-        n_train = floor(train_size * n_samples)
     else:
-        n_train = float(train_size)
+        # both train_size and test_size set, so subsample
+        if np.asarray(train_size).dtype.kind == 'f':
+            n_train = floor(train_size * n_samples)
+        else:
+            n_train = float(train_size)

-    if test_size is None:
-        n_test = n_samples - n_train
+        if np.asarray(test_size).dtype.kind == 'f':
+            n_test = ceil(test_size * n_samples)
+        else:
+            n_test = float(test_size)

     if n_train + n_test > n_samples:
         raise ValueError('The sum of train_size and test_size = %d, '
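Worth noting in both copies of the helper: a fractional test_size is rounded up with ceil while a fractional train_size is rounded down with floor. Because floor(0.9 * n) + ceil(0.1 * n) == n, the default 0.9/0.1 pair, train_size=.9 alone, and test_size=.1 alone all resolve to the same (n_train, n_test), which is exactly what the new tests added below rely on. A quick illustrative check:

    from math import ceil, floor

    for n in (5, 10, 12, 15, 101):
        assert floor(0.9 * n) + ceil(0.1 * n) == n
        assert n - ceil(0.1 * n) == floor(0.9 * n)
    print("floor/ceil rounding directions are complementary")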
15 changes: 13 additions & 2 deletions sklearn/model_selection/tests/test_split.py
@@ -156,7 +156,7 @@ def test_cross_validator_with_default_params():
     skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
     lolo_repr = "LeaveOneLabelOut()"
     lopo_repr = "LeavePLabelOut(n_labels=2)"
-    ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, test_size=0.1, "
+    ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, test_size=None, "
                "train_size=None)")
     ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"

@@ -807,7 +807,6 @@ def train_test_split_mock_pandas():

 def test_shufflesplit_errors():
     # When the {test|train}_size is a float/invalid, error is raised at init
-    assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None)
     assert_raises(ValueError, ShuffleSplit, test_size=2.0)
     assert_raises(ValueError, ShuffleSplit, test_size=1.0)
     assert_raises(ValueError, ShuffleSplit, test_size=0.1, train_size=0.95)
@@ -829,6 +828,18 @@ def test_shufflesplit_reproducible():
                        list(a for a, b in ss.split(X)))


+def test_shufflesplit_train_test_size():
+    # check that same sequence of train-test is given
+    # when setting train_size to be the complement of test_size
+    # and vice-versa
+    ss_default = ShuffleSplit(random_state=0)
+    ss_train = ShuffleSplit(random_state=0, train_size=.9)
+    ss_test = ShuffleSplit(random_state=0, test_size=.1)
+    assert_array_equal(list(a for a, b in ss_default.split(X)),
+                       list(a for a, b in ss_train.split(X)),
+                       list(a for a, b in ss_test.split(X)))
+
+
 def test_safe_split_with_precomputed_kernel():
     clf = SVC()
     clfp = SVC(kernel="precomputed")
36 changes: 24 additions & 12 deletions sklearn/tests/test_cross_validation.py
@@ -24,10 +24,6 @@
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.mocking import CheckingClassifier, MockDataFrame

-with warnings.catch_warnings():
-    warnings.simplefilter('ignore')
-    from sklearn import cross_validation as cval
-
 from sklearn.datasets import make_regression
 from sklearn.datasets import load_boston
 from sklearn.datasets import load_digits
@@ -48,6 +44,10 @@
 from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline

+with warnings.catch_warnings():
+    warnings.simplefilter('ignore')
+    from sklearn import cross_validation as cval
+

 class MockClassifier(object):
     """Dummy classifier to test the cross-validation"""
@@ -490,10 +490,11 @@ def test_stratified_shuffle_split_iter():
     for train, test in sss:
         assert_array_equal(np.unique(y[train]), np.unique(y[test]))
         # Checks if folds keep classes proportions
-        p_train = (np.bincount(np.unique(y[train], return_inverse=True)[1])
-                   / float(len(y[train])))
-        p_test = (np.bincount(np.unique(y[test], return_inverse=True)[1])
-                   / float(len(y[test])))
+        p_train = (np.bincount(np.unique(y[train],
+                               return_inverse=True)[1]) /
+                   float(len(y[train])))
+        p_test = (np.bincount(np.unique(y[test], return_inverse=True)[1]) /
+                  float(len(y[test])))
         assert_array_almost_equal(p_train, p_test, 1)
         assert_equal(y[train].size + y[test].size, y.size)
         assert_array_equal(np.intersect1d(train, test), [])
@@ -862,6 +863,7 @@ def train_test_split_pandas():
         assert_true(isinstance(X_train, InputFeatureType))
         assert_true(isinstance(X_test, InputFeatureType))

+
 def train_test_split_mock_pandas():
     # X mock dataframe
     X_df = MockDataFrame(X)
@@ -948,8 +950,8 @@ def test_permutation_score():

     # test with custom scoring object
     def custom_score(y_true, y_pred):
-        return (((y_true == y_pred).sum() - (y_true != y_pred).sum())
-                / y_true.shape[0])
+        return (((y_true == y_pred).sum() - (y_true != y_pred).sum()) /
+                y_true.shape[0])

     scorer = make_scorer(custom_score)
     score, _, pvalue = cval.permutation_test_score(
@@ -1018,8 +1020,6 @@ def test_shufflesplit_errors():
     assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=10)
     assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=8, train_size=3)
     assert_raises(ValueError, cval.ShuffleSplit, 10, train_size=1j)
-    assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=None,
-                  train_size=None)


 def test_shufflesplit_reproducible():
@@ -1029,6 +1029,18 @@ def test_shufflesplit_reproducible():
     assert_array_equal(list(a for a, b in ss), list(a for a, b in ss))


+def test_shufflesplit_train_test_size():
+    # check that same sequence of train-test is given
+    # when setting train_size to be the complement of test_size
+    # and vice-versa
+    ss_default = cval.ShuffleSplit(10, random_state=0)
+    ss_train = cval.ShuffleSplit(10, random_state=0, train_size=.9)
+    ss_test = cval.ShuffleSplit(10, random_state=0, test_size=.1)
+    assert_array_equal(list(a for a, b in ss_default),
+                       list(a for a, b in ss_train),
+                       list(a for a, b in ss_test))
+
+
 def test_safe_split_with_precomputed_kernel():
     clf = SVC()
     clfp = SVC(kernel="precomputed")
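The same regression test is added to both modules; the call patterns differ because the deprecated cross_validation.ShuffleSplit takes the number of samples in its constructor and is iterated directly, while model_selection.ShuffleSplit infers it from X inside split(). A small side-by-side sketch (again assuming a checkout with this commit applied):

    import warnings
    import numpy as np
    from sklearn.model_selection import ShuffleSplit

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')  # the old module warns on import
        from sklearn import cross_validation as cval

    X = np.arange(20).reshape((10, 2))

    new_style = ShuffleSplit(random_state=0)           # n comes from X at split time
    old_style = cval.ShuffleSplit(10, random_state=0)  # n passed up front

    train_new, test_new = next(iter(new_style.split(X)))
    train_old, test_old = next(iter(old_style))
    print(len(train_new), len(test_new))  # 9 1
    print(len(train_old), len(test_old))  # 9 1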
