StratifiedShuffleSplit generates overlapping train and test indices #6121

Closed
XuesongYang opened this issue Jan 6, 2016 · 6 comments

@XuesongYang

Why is there overlap between dev_idx and t_idx in the following code? There should be no overlap.

```
train_test_split = StratifiedShuffleSplit(labels, n_iter=1, test_size=0.2, random_state=0)
for train_idx, test_idx in train_test_split:
    train_tmp = set(train_idx)
    test_tmp = set(test_idx)
    assert_equal(train_tmp.intersection(test_tmp), set())
    X_train = np.copy(feats[train_idx])
    y_train = np.copy(labels[train_idx])
    trans_train = np.copy(trans[train_idx])
    X_valid = np.copy(feats[test_idx])
    y_valid = np.copy(labels[test_idx])
    trans_valid = np.copy(trans[test_idx])
del feats
del labels
del trans
dev_test_split = StratifiedShuffleSplit(y_valid, n_iter=1, test_size=0.5, random_state=0)
for dev_idx, t_idx in dev_test_split:
    dev_tmp = set(dev_idx)
    t_tmp = set(t_idx)
    assert_equal(dev_tmp.intersection(t_tmp), set())
    X_dev = np.copy(X_valid[dev_idx])
    y_dev = np.copy(y_valid[dev_idx])
    trans_dev = np.copy(trans_valid[dev_idx])
    X_test = np.copy(X_valid[t_idx])
    y_test = np.copy(y_valid[t_idx])
    trans_test = np.copy(trans_valid[t_idx])
del X_valid
del y_valid
del trans_valid
```

The second assert_equal() check raised the following error:

    assert_equal(dev_tmp.intersection(t_tmp), set())
  File "/home/xyang45/miniconda2/lib/python2.7/unittest/case.py", line 513, in assertEqual
    assertion_func(first, second, msg=msg)
  File "/home/xyang45/miniconda2/lib/python2.7/unittest/case.py", line 796, in assertSetEqual
    self.fail(self._formatMessage(msg, standardMsg))
  File "/home/xyang45/miniconda2/lib/python2.7/unittest/case.py", line 410, in fail
    raise self.failureException(msg)
AssertionError: Items in the first set but not the second:
1160
1161
907
1070
1747
2232
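
For readers debugging a similar split, the set-based check used above can be wrapped in a small helper that reports exactly which indices are shared. This is only a sketch: `overlapping_indices` and `check_disjoint` are hypothetical names introduced here for illustration, not part of scikit-learn or of the original report.

```python
def overlapping_indices(first_idx, second_idx):
    """Return the sorted indices that appear in both collections."""
    return sorted(set(first_idx) & set(second_idx))


def check_disjoint(first_idx, second_idx):
    """Raise AssertionError listing the shared indices, if there are any."""
    shared = overlapping_indices(first_idx, second_idx)
    if shared:
        raise AssertionError("indices present in both splits: %s" % shared)
```

Calling `check_disjoint(dev_idx, t_idx)` inside the second loop would be one way to surface the offending indices in a single message.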
@lesteve
Member

lesteve commented Jan 6, 2016

Could you put together a standalone example so that we can try to reproduce the problem?

@XuesongYang
Author

Please see the attached reproducible example. Thanks.

bugs_sklearn.zip

@lesteve
Member

lesteve commented Jan 7, 2016

I can reproduce it. Here is a standalone snippet that reproduces the problem:

from sklearn.cross_validation import StratifiedShuffleSplit
from numpy.testing import assert_array_equal
import numpy as np

rng = np.random.RandomState(0)
labels = rng.randint(low=0, high=10, size=100)
sss = StratifiedShuffleSplit(labels, n_iter=1,
                             test_size=0.5, random_state=0)

train, test = next(iter(sss))

assert_array_equal(np.intersect1d(train, test), [])

The output:

AssertionError: 
Arrays are not equal

(shapes (1,), (0,) mismatch)
 x: array([89])
 y: array([], dtype=float64)

@lesteve
Member

lesteve commented Jan 7, 2016

I also confirmed that this issue happens on master, 0.17, and 0.16. I did not check older versions.

@lesteve
Member

lesteve commented Jan 7, 2016

@MagicYoung can you tweak the title so that it is self-explanatory, e.g. something like "StratifiedShuffleSplit generates overlapping train and test indices"?

XuesongYang changed the title from "the splits of data set has overlap!" to "StratifiedShuffleSplit generates overlapping train and test indices" on Jan 7, 2016
@XuesongYang
Author

I have not checked the implementation of StratifiedShuffleSplit, but I guess the issue is related to how the samples are distributed across classes in the array.

@lesteve The title has been changed.
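
To make the expected behaviour concrete, here is a minimal pure-Python sketch of a stratified shuffle split in which train and test are disjoint by construction, whatever the class distribution. It is an illustration only and not scikit-learn's implementation: the function name `stratified_split`, its parameters, and the per-class rounding rule are all assumptions made for this example.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.5, seed=0):
    """Split positions of `labels` into (train, test) lists, class by class."""
    rng = random.Random(seed)

    # Group sample positions by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, test = [], []
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Assumed rounding rule: nearest integer per class, so the total
        # test size can differ slightly from test_size * len(labels).
        n_test = int(round(len(members) * test_size))
        # Each position goes to exactly one side, so no overlap is possible.
        test.extend(members[:n_test])
        train.extend(members[n_test:])

    rng.shuffle(train)
    rng.shuffle(test)
    return train, test
```

Because every index is assigned to exactly one of the two lists, `set(train) & set(test)` is always empty here. Any overlap in a real splitter therefore has to come from an extra step, such as adjusting the split sizes after the per-class allocation, rather than from the labels themselves.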
