Skip to content

Commit 31a4691

Browse files
amueller authored and ogrisel committed
[MRG+1] fix sampling in stratified shuffle split (#6472)
Fix sampling in stratified shuffle split, break tests that test sampling.
1 parent 49d126f commit 31a4691

File tree

6 files changed

+189
-69
lines changed

6 files changed

+189
-69
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,11 @@ Bug fixes
397397
- Fix :class:`linear_model.ElasticNet` sparse decision function to match
398398
output with dense in the multioutput case.
399399

400+
- Fix in :class:`sklearn.model_selection.StratifiedShuffleSplit` to
401+
return splits of size ``train_size`` and ``test_size`` in all cases
402+
(`#6472 <https://github.com/scikit-learn/scikit-learn/pull/6472>`_).
403+
By `Andreas Müller`_.
404+
400405
API changes summary
401406
-------------------
402407

sklearn/cross_validation.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .utils.validation import (_is_arraylike, _num_samples,
2828
column_or_1d)
2929
from .utils.multiclass import type_of_target
30+
from .utils.random import choice
3031
from .externals.joblib import Parallel, delayed, logger
3132
from .externals.six import with_metaclass
3233
from .externals.six.moves import zip
@@ -414,9 +415,9 @@ def __init__(self, labels, n_folds=3):
414415

415416
if n_folds > n_labels:
416417
raise ValueError(
417-
("Cannot have number of folds n_folds={0} greater"
418-
" than the number of labels: {1}.").format(n_folds,
419-
n_labels))
418+
("Cannot have number of folds n_folds={0} greater"
419+
" than the number of labels: {1}.").format(n_folds,
420+
n_labels))
420421

421422
# Weight labels by their number of occurrences
422423
n_samples_per_label = np.bincount(labels)
@@ -906,6 +907,59 @@ def _validate_shuffle_split(n, test_size, train_size):
906907
return int(n_train), int(n_test)
907908

908909

910+
def _approximate_mode(class_counts, n_draws, rng):
911+
"""Computes approximate mode of multivariate hypergeometric.
912+
913+
This is an approximation to the mode of the multivariate
914+
hypergeometric given by class_counts and n_draws.
915+
It shouldn't be off by more than one.
916+
917+
It is the most likely outcome of drawing n_draws many
918+
samples from the population given by class_counts.
919+
920+
Parameters
921+
----------
922+
class_counts : ndarray of int
923+
Population per class.
924+
n_draws : int
925+
Number of draws (samples to draw) from the overall population.
926+
rng : random state
927+
Used to break ties.
928+
929+
Returns
930+
-------
931+
sampled_classes : ndarray of int
932+
Number of samples drawn from each class.
933+
np.sum(sampled_classes) == n_draws
934+
"""
935+
# this computes a bad approximation to the mode of the
936+
# multivariate hypergeometric given by class_counts and n_draws
937+
continuous = n_draws * class_counts / class_counts.sum()
938+
# flooring means we don't overshoot n_draws, but we probably undershoot
939+
floored = np.floor(continuous)
940+
# we add samples according to how much "left over" probability
941+
# they had, until we arrive at n_draws
942+
need_to_add = int(n_draws - floored.sum())
943+
if need_to_add > 0:
944+
remainder = continuous - floored
945+
values = np.sort(np.unique(remainder))[::-1]
946+
# add according to remainder, but break ties
947+
# randomly to avoid biases
948+
for value in values:
949+
inds, = np.where(remainder == value)
950+
# if we need_to_add less than what's in inds
951+
# we draw randomly from them.
952+
# if we need to add more, we add them all and
953+
# go to the next value
954+
add_now = min(len(inds), need_to_add)
955+
inds = choice(inds, size=add_now, replace=False, random_state=rng)
956+
floored[inds] += 1
957+
need_to_add -= add_now
958+
if need_to_add == 0:
959+
break
960+
return floored.astype(np.int)
961+
962+
909963
class StratifiedShuffleSplit(BaseShuffleSplit):
910964
"""Stratified ShuffleSplit cross validation iterator
911965
@@ -991,39 +1045,24 @@ def __init__(self, y, n_iter=10, test_size=0.1, train_size=None,
9911045
def _iter_indices(self):
9921046
rng = check_random_state(self.random_state)
9931047
cls_count = bincount(self.y_indices)
994-
p_i = cls_count / float(self.n)
995-
n_i = np.round(self.n_train * p_i).astype(int)
996-
t_i = np.minimum(cls_count - n_i,
997-
np.round(self.n_test * p_i).astype(int))
9981048

9991049
for n in range(self.n_iter):
1050+
# if there are ties in the class-counts, we want
1051+
# to make sure to break them anew in each iteration
1052+
n_i = _approximate_mode(cls_count, self.n_train, rng)
1053+
class_counts_remaining = cls_count - n_i
1054+
t_i = _approximate_mode(class_counts_remaining, self.n_test, rng)
1055+
10001056
train = []
10011057
test = []
10021058

1003-
for i, cls in enumerate(self.classes):
1059+
for i, _ in enumerate(self.classes):
10041060
permutation = rng.permutation(cls_count[i])
1005-
cls_i = np.where((self.y == cls))[0][permutation]
1006-
1007-
train.extend(cls_i[:n_i[i]])
1008-
test.extend(cls_i[n_i[i]:n_i[i] + t_i[i]])
1009-
1010-
# Because of rounding issues (as n_train and n_test are not
1011-
# dividers of the number of elements per class), we may end
1012-
# up here with less samples in train and test than asked for.
1013-
if len(train) + len(test) < self.n_train + self.n_test:
1014-
# We complete by affecting randomly the missing indexes
1015-
missing_idx = np.where(bincount(train + test,
1016-
minlength=len(self.y)) == 0,
1017-
)[0]
1018-
missing_idx = rng.permutation(missing_idx)
1019-
n_missing_train = self.n_train - len(train)
1020-
n_missing_test = self.n_test - len(test)
1021-
1022-
if n_missing_train > 0:
1023-
train.extend(missing_idx[:n_missing_train])
1024-
if n_missing_test > 0:
1025-
test.extend(missing_idx[-n_missing_test:])
1061+
perm_indices_class_i = np.where(
1062+
(i == self.y_indices))[0][permutation]
10261063

1064+
train.extend(perm_indices_class_i[:n_i[i]])
1065+
test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])
10271066
train = rng.permutation(train)
10281067
test = rng.permutation(test)
10291068

sklearn/model_selection/_split.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..externals.six.moves import zip
3131
from ..utils.fixes import bincount
3232
from ..utils.fixes import signature
33+
from ..utils.random import choice
3334
from ..base import _pprint
3435
from ..gaussian_process.kernels import Kernel as GPKernel
3536

@@ -1098,6 +1099,73 @@ def _iter_indices(self, X, y, labels):
10981099
yield train, test
10991100

11001101

1102+
def _approximate_mode(class_counts, n_draws, rng):
1103+
"""Computes approximate mode of multivariate hypergeometric.
1104+
1105+
This is an approximation to the mode of the multivariate
1106+
hypergeometric given by class_counts and n_draws.
1107+
It shouldn't be off by more than one.
1108+
1109+
It is the most likely outcome of drawing n_draws many
1110+
samples from the population given by class_counts.
1111+
1112+
Parameters
1113+
----------
1114+
class_counts : ndarray of int
1115+
Population per class.
1116+
n_draws : int
1117+
Number of draws (samples to draw) from the overall population.
1118+
rng : random state
1119+
Used to break ties.
1120+
1121+
Returns
1122+
-------
1123+
sampled_classes : ndarray of int
1124+
Number of samples drawn from each class.
1125+
np.sum(sampled_classes) == n_draws
1126+
1127+
Examples
1128+
--------
1129+
>>> from sklearn.model_selection._split import _approximate_mode
1130+
>>> _approximate_mode(class_counts=np.array([4, 2]), n_draws=3, rng=0)
1131+
array([2, 1])
1132+
>>> _approximate_mode(class_counts=np.array([5, 2]), n_draws=4, rng=0)
1133+
array([3, 1])
1134+
>>> _approximate_mode(class_counts=np.array([2, 2, 2, 1]),
1135+
... n_draws=2, rng=0)
1136+
array([0, 1, 1, 0])
1137+
>>> _approximate_mode(class_counts=np.array([2, 2, 2, 1]),
1138+
... n_draws=2, rng=42)
1139+
array([1, 1, 0, 0])
1140+
"""
1141+
# this computes a bad approximation to the mode of the
1142+
# multivariate hypergeometric given by class_counts and n_draws
1143+
continuous = n_draws * class_counts / class_counts.sum()
1144+
# flooring means we don't overshoot n_draws, but we probably undershoot
1145+
floored = np.floor(continuous)
1146+
# we add samples according to how much "left over" probability
1147+
# they had, until we arrive at n_draws
1148+
need_to_add = int(n_draws - floored.sum())
1149+
if need_to_add > 0:
1150+
remainder = continuous - floored
1151+
values = np.sort(np.unique(remainder))[::-1]
1152+
# add according to remainder, but break ties
1153+
# randomly to avoid biases
1154+
for value in values:
1155+
inds, = np.where(remainder == value)
1156+
# if we need_to_add less than what's in inds
1157+
# we draw randomly from them.
1158+
# if we need to add more, we add them all and
1159+
# go to the next value
1160+
add_now = min(len(inds), need_to_add)
1161+
inds = choice(inds, size=add_now, replace=False, random_state=rng)
1162+
floored[inds] += 1
1163+
need_to_add -= add_now
1164+
if need_to_add == 0:
1165+
break
1166+
return floored.astype(np.int)
1167+
1168+
11011169
class StratifiedShuffleSplit(BaseShuffleSplit):
11021170
"""Stratified ShuffleSplit cross-validator
11031171
@@ -1181,12 +1249,14 @@ def _iter_indices(self, X, y, labels=None):
11811249
(n_test, n_classes))
11821250

11831251
rng = check_random_state(self.random_state)
1184-
p_i = class_counts / float(n_samples)
1185-
n_i = np.round(n_train * p_i).astype(int)
1186-
t_i = np.minimum(class_counts - n_i,
1187-
np.round(n_test * p_i).astype(int))
11881252

11891253
for _ in range(self.n_splits):
1254+
# if there are ties in the class-counts, we want
1255+
# to make sure to break them anew in each iteration
1256+
n_i = _approximate_mode(class_counts, n_train, rng)
1257+
class_counts_remaining = class_counts - n_i
1258+
t_i = _approximate_mode(class_counts_remaining, n_test, rng)
1259+
11901260
train = []
11911261
test = []
11921262

@@ -1196,23 +1266,6 @@ def _iter_indices(self, X, y, labels=None):
11961266

11971267
train.extend(perm_indices_class_i[:n_i[i]])
11981268
test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])
1199-
1200-
# Because of rounding issues (as n_train and n_test are not
1201-
# dividers of the number of elements per class), we may end
1202-
# up here with less samples in train and test than asked for.
1203-
if len(train) + len(test) < n_train + n_test:
1204-
# We complete by affecting randomly the missing indexes
1205-
missing_indices = np.where(bincount(train + test,
1206-
minlength=len(y)) == 0)[0]
1207-
missing_indices = rng.permutation(missing_indices)
1208-
n_missing_train = n_train - len(train)
1209-
n_missing_test = n_test - len(test)
1210-
1211-
if n_missing_train > 0:
1212-
train.extend(missing_indices[:n_missing_train])
1213-
if n_missing_test > 0:
1214-
test.extend(missing_indices[-n_missing_test:])
1215-
12161269
train = rng.permutation(train)
12171270
test = rng.permutation(test)
12181271

sklearn/model_selection/tests/test_split.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -535,17 +535,33 @@ def test_stratified_shuffle_split_init():
535535
StratifiedShuffleSplit(test_size=2).split(X, y))
536536

537537

538+
def test_stratified_shuffle_split_respects_test_size():
539+
y = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2])
540+
test_size = 5
541+
train_size = 10
542+
sss = StratifiedShuffleSplit(6, test_size=test_size, train_size=train_size,
543+
random_state=0).split(np.ones(len(y)), y)
544+
for train, test in sss:
545+
assert_equal(len(train), train_size)
546+
assert_equal(len(test), test_size)
547+
548+
538549
def test_stratified_shuffle_split_iter():
539550
ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
540551
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
541-
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
552+
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
542553
np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
543-
np.array([-1] * 800 + [1] * 50)
554+
np.array([-1] * 800 + [1] * 50),
555+
np.concatenate([[i] * (100 + i) for i in range(11)])
544556
]
545557

546558
for y in ys:
547559
sss = StratifiedShuffleSplit(6, test_size=0.33,
548560
random_state=0).split(np.ones(len(y)), y)
561+
# this is how test-size is computed internally
562+
# in _validate_shuffle_split
563+
test_size = np.ceil(0.33 * len(y))
564+
train_size = len(y) - test_size
549565
for train, test in sss:
550566
assert_array_equal(np.unique(y[train]), np.unique(y[test]))
551567
# Checks if folds keep classes proportions
@@ -556,7 +572,9 @@ def test_stratified_shuffle_split_iter():
556572
return_inverse=True)[1]) /
557573
float(len(y[test])))
558574
assert_array_almost_equal(p_train, p_test, 1)
559-
assert_equal(y[train].size + y[test].size, y.size)
575+
assert_equal(len(train) + len(test), y.size)
576+
assert_equal(len(train), train_size)
577+
assert_equal(len(test), test_size)
560578
assert_array_equal(np.lib.arraysetops.intersect1d(train, test), [])
561579

562580

@@ -572,8 +590,8 @@ def assert_counts_are_ok(idx_counts, p):
572590
threshold = 0.05 / n_splits
573591
bf = stats.binom(n_splits, p)
574592
for count in idx_counts:
575-
p = bf.pmf(count)
576-
assert_true(p > threshold,
593+
prob = bf.pmf(count)
594+
assert_true(prob > threshold,
577595
"An index is not drawn with chance corresponding "
578596
"to even draws")
579597

@@ -593,9 +611,8 @@ def assert_counts_are_ok(idx_counts, p):
593611
counter[id] += 1
594612
assert_equal(n_splits_actual, n_splits)
595613

596-
n_train, n_test = _validate_shuffle_split(n_samples,
597-
test_size=1./n_folds,
598-
train_size=1.-(1./n_folds))
614+
n_train, n_test = _validate_shuffle_split(
615+
n_samples, test_size=1. / n_folds, train_size=1. - (1. / n_folds))
599616

600617
assert_equal(len(train), n_train)
601618
assert_equal(len(test), n_test)
@@ -656,7 +673,7 @@ def test_label_shuffle_split():
656673
for l in labels:
657674
X = y = np.ones(len(l))
658675
n_splits = 6
659-
test_size = 1./3
676+
test_size = 1. / 3
660677
slo = LabelShuffleSplit(n_splits, test_size=test_size, random_state=0)
661678

662679
# Make sure the repr works

sklearn/tests/test_cross_validation.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -479,24 +479,30 @@ def test_stratified_shuffle_split_init():
479479
def test_stratified_shuffle_split_iter():
480480
ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
481481
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
482-
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
482+
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
483483
np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
484484
np.array([-1] * 800 + [1] * 50)
485485
]
486486

487487
for y in ys:
488488
sss = cval.StratifiedShuffleSplit(y, 6, test_size=0.33,
489489
random_state=0)
490+
test_size = np.ceil(0.33 * len(y))
491+
train_size = len(y) - test_size
490492
for train, test in sss:
491493
assert_array_equal(np.unique(y[train]), np.unique(y[test]))
492494
# Checks if folds keep classes proportions
493-
p_train = (np.bincount(np.unique(y[train], return_inverse=True)[1])
494-
/ float(len(y[train])))
495-
p_test = (np.bincount(np.unique(y[test], return_inverse=True)[1])
496-
/ float(len(y[test])))
495+
p_train = (np.bincount(np.unique(y[train],
496+
return_inverse=True)[1]) /
497+
float(len(y[train])))
498+
p_test = (np.bincount(np.unique(y[test],
499+
return_inverse=True)[1]) /
500+
float(len(y[test])))
497501
assert_array_almost_equal(p_train, p_test, 1)
498-
assert_equal(y[train].size + y[test].size, y.size)
499-
assert_array_equal(np.intersect1d(train, test), [])
502+
assert_equal(len(train) + len(test), y.size)
503+
assert_equal(len(train), train_size)
504+
assert_equal(len(test), test_size)
505+
assert_array_equal(np.lib.arraysetops.intersect1d(train, test), [])
500506

501507

502508
def test_stratified_shuffle_split_even():

0 commit comments

Comments
 (0)