
[MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting #7593

Merged: 18 commits merged into scikit-learn:master from raghavrv:check_X_y_groups on Nov 3, 2016

Conversation

4 participants
@raghavrv (Member) commented Oct 6, 2016

Fixes #7582 and #7126

With scikit-learn 0.18.0:

>>> from sklearn.model_selection import train_test_split
>>> X, y = [[1], [2], [3], [4], [5], [6]], ['1', '2', '1', '2', '1', '2']
>>> _ = train_test_split(X, y, stratify=y)
IndexError: index 0 is out of bounds for axis 1 with size 0

That is fixed by this PR.
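As a quick post-fix sanity check (illustrative only, not part of the PR's test suite; the exact split sizes depend on the default test_size, so only the total length is checked):

>>> X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
>>> len(X_train) + len(X_test)
6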

This PR also cleans up some docstrings and adds tests for LeavePGroupsOut and LeaveOneGroupOut.

@jnothman @amueller @lesteve Reviews please :)

@@ -1708,3 +1714,23 @@ def _build_repr(self):
params[key] = value

return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))


def _check_X_y_groups(X, y, groups):

@raghavrv (Author, Member) on Oct 6, 2016

Should this reside inside utils.validation?

@amueller (Member) on Oct 7, 2016

probably. Is the same applicable for sample_weights? What do we usually do with sample_weights?
We might just write check_X_y and then do a check_consistent_length(X, sample_weights) and check_array(sample_weights).
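For illustration only, a rough sketch of that idea; the helper name _check_X_y_sample_weight is hypothetical, while check_X_y, check_array and check_consistent_length are existing utilities in sklearn.utils.validation:

from sklearn.utils.validation import check_X_y, check_array, check_consistent_length

def _check_X_y_sample_weight(X, y, sample_weight=None):
    # Validate X and y together, then validate the optional weights
    # and make sure they line up with the samples.
    X, y = check_X_y(X, y)
    if sample_weight is not None:
        sample_weight = check_array(sample_weight, ensure_2d=False)
        check_consistent_length(X, sample_weight)
    return X, y, sample_weight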

@raghavrv raghavrv force-pushed the raghavrv:check_X_y_groups branch from d45b75c to ff5f379 Oct 6, 2016

allow_nd=True)
check_consistent_length(X, y)
if groups is not None:
groups = check_array(groups, accept_sparse=['coo', 'csr', 'csc'],

@amueller (Member) on Oct 7, 2016

groups can be infinite? and sparse? and nd? Is that tested? ;)

@jnothman (Member) on Oct 8, 2016

cannot be sparse, surely.

@jnothman (Member) on Oct 8, 2016

or nd

dtype=None, force_all_finite=False, ensure_2d=False,
allow_nd=True)
if y is not None:
y = check_array(y, accept_sparse=['coo', 'csr', 'csc'],

@amueller (Member) on Oct 7, 2016

Same for y. Are these tested? Should they be? I guess we should be as loose as possible with the test as long as the cross-validation classes work.

@raghavrv (Author, Member) commented Oct 8, 2016

There is a test for train_test_split which checks support for nd arrays... And we cannot allow nd arrays only there, as it uses ShuffleSplit internally... You are correct, groups cannot have nan or be nd, but I think they can be sparse...

And we could do check_X_y followed by checks for groups, but it doesn't allow a None for y.

@raghavrv raghavrv removed the Needs Review label Oct 9, 2016

@raghavrv raghavrv force-pushed the raghavrv:check_X_y_groups branch from 76027d5 to 578442b Oct 12, 2016

@raghavrv (Author, Member) commented Oct 12, 2016

Okay, I did away with the helper and added case-by-case minimal validation for y and groups. For X, only indexability is checked. One more pass @jnothman @amueller please!
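For context, the indexability check here corresponds to the existing sklearn.utils.indexable helper; a minimal illustration (not code from this PR), which to my understanding leaves list inputs as-is and raises ValueError on inconsistent lengths:

from sklearn.utils import indexable

X = [[1], [2], [3]]
y = ['a', 'b', 'c']
# Lists pass through unchanged; mismatched lengths would raise a ValueError.
X, y = indexable(X, y)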

@@ -843,6 +877,20 @@ def test_shufflesplit_reproducible():
list(a for a, b in ss.split(X)))


def test_shufflesplit_list_input():
# Check that when y is a list / list of string labels, it works.
ss = ShuffleSplit(random_state=42)

@amueller (Member) on Oct 13, 2016

shouldn't that be StratifiedShuffleSplit?

@@ -1087,6 +1091,8 @@ def __init__(self, n_splits=5, test_size=0.2, train_size=None,
def _iter_indices(self, X, y, groups):
if groups is None:
raise ValueError("The groups parameter should not be None")
groups = check_array(groups, ensure_2d=False, dtype=None)

@amueller (Member) on Oct 13, 2016

How about GroupKFold, LeaveOneGroupOut, LeavePGroupsOut?

@raghavrv (Author, Member) on Oct 16, 2016

Fixed... Thanks for the catch!!

@raghavrv (Author, Member) commented Oct 16, 2016

I fixed #7126 along the way... One more look at this @amueller @jnothman

@raghavrv (Author, Member) commented Oct 16, 2016

Argh. There seem to have been no tests for LeavePGroupsOut and LeaveOneGroupOut in either the old or the new tests... Have added them too...

RPGOne approved these changes Oct 17, 2016

@raghavrv raghavrv force-pushed the raghavrv:check_X_y_groups branch from f117a07 to 13f1e95 Oct 17, 2016

@@ -891,6 +901,8 @@ def get_n_splits(self, X, y, groups):
"""
if groups is None:
raise ValueError("The groups parameter should not be None")
X, y, groups = indexable(X, y, groups)
groups = check_array(groups, ensure_2d=False, dtype=None)

@amueller (Member) on Oct 17, 2016

I'd do it the other way around, I think.

@raghavrv (Author, Member) on Oct 17, 2016

check_array followed by indexable?

@amueller (Member) left a comment

Looks good apart from some nitpicks.


for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1),
(lpgo_2, 2))):
groups = (np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),

@amueller (Member) on Oct 17, 2016

do we want these to be file-level constants?

logo = LeaveOneGroupOut()
lpgo_1 = LeavePGroupsOut(n_groups=1)
lpgo_2 = LeavePGroupsOut(n_groups=2)
lpgo_3 = LeavePGroupsOut(n_groups=3)

@amueller (Member) on Oct 17, 2016

for this one you only test the repr, right?

[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'])

all_n_splits = np.array([[3, 3, 3],

@amueller (Member) on Oct 17, 2016

why do you hard-code it like this? that seems hard to validate. It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out) right?

@raghavrv (Author, Member) on Oct 19, 2016

Re: "It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out), right?"

That is the implementation in _split.py. I thought it would be better to compare it against hand-calculated values?

@amueller (Member) on Oct 19, 2016

Hm. The correctness of your "hand-calculated values" is not immediately obvious to me. How about

n_groups = len(np.unique(groups_i))
n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2

instead? But I'm also fine leaving it as it is.
Why is all_n_splits of length 7 when groups is of length 6? (Or is GitHub showing me a weird diff?)
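For reference, this is what the comb-based expectation looks like in code (a standalone sketch using scipy.special.comb, the same function as the scipy.misc.comb mentioned above; the values shown are for the first groups array in the test):

import numpy as np
from scipy.special import comb

groups = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])
n_groups = len(np.unique(groups))            # 3 distinct groups
for p_groups_out in (1, 2):
    expected_n_splits = comb(n_groups, p_groups_out, exact=True)
    print(p_groups_out, expected_n_splits)   # 1 -> 3 splits, 2 -> 3 splits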

# First test: no train group is in the test set and vice versa
grps_train_unique = np.unique(groups_arr[train])
grps_test_unique = np.unique(groups_arr[test])
assert_false(np.any(np.in1d(groups_arr[train],

@amueller (Member) on Oct 17, 2016

why not test the intersection is empty?
assert_equal(set(groups_arr[train]).intersection(groups_arr[test]), set())

@amueller (Member) on Oct 17, 2016

(or intersect1d if you prefer)

@raghavrv (Author, Member) on Oct 19, 2016

Sure. Thanks

@raghavrv (Author, Member) on Oct 20, 2016

Wait, that is already done in the next two lines...

@raghavrv (Author, Member) on Oct 20, 2016

("third test")

@amueller (Member) on Oct 20, 2016

The third test checks whether the indices are disjoint; my code checks whether the groups are disjoint.
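A toy example of the difference (not code from the PR):

import numpy as np

groups = np.array(['a', 'a', 'b', 'b'])
train, test = np.array([0, 1]), np.array([2, 3])
# Indices are disjoint: no sample index appears in both splits.
assert len(np.intersect1d(train, test)) == 0
# Groups are disjoint: no group label appears on both sides.
assert set(groups[train]).isdisjoint(groups[test])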

grps_train_unique)))

# Second test: train and test add up to all the data
assert_equal(groups_arr[train].size +

@amueller (Member) on Oct 17, 2016

len(train) + len(test) = len(groups)?

@amueller amueller changed the title [MRG] FIX Validate and convert X, y and groups to ndarray before splitting [MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting Oct 19, 2016

@amueller (Member) commented Oct 19, 2016

lgtm apart from minor comments

@raghavrv raghavrv force-pushed the raghavrv:check_X_y_groups branch from 13f1e95 to fce36af Oct 20, 2016

@raghavrv (Author, Member) commented Oct 20, 2016

Have addressed your comments. A second look please, @jnothman @vene @TomDLT?

@raghavrv raghavrv changed the title [MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting [MRG + 2] FIX Validate and convert X, y and groups to ndarray before splitting Oct 20, 2016

@raghavrv raghavrv changed the title [MRG + 2] FIX Validate and convert X, y and groups to ndarray before splitting [MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting Oct 20, 2016

@raghavrv raghavrv force-pushed the raghavrv:check_X_y_groups branch from b5d1fe3 to 44f6db6 Oct 24, 2016

@amueller (Member) commented Oct 24, 2016

travis fails?

@raghavrv (Author, Member) commented Oct 24, 2016

Sorry about that. Should be fixed now...

@amueller amueller added the Blocker label Oct 25, 2016

@raghavrv raghavrv force-pushed the raghavrv:check_X_y_groups branch from 0516776 to 1ca13d1 Nov 3, 2016

np.testing.assert_equal(y_train2, y_train3)
np.testing.assert_equal(X_test1, X_test3)
np.testing.assert_equal(y_test3, y_test2)
for stratify in ((y1, y2, y3), (None, None, None)):

@raghavrv (Author, Member) on Nov 3, 2016

Does this seem okay? @jnothman @amueller

@raghavrv (Author, Member) commented Nov 3, 2016

Apologies for the delay! Have rebased and added the test... Could you check if it's okay?


for stratify in ((y1, y2, y3), (None, None, None)):
X_train1, X_test1, y_train1, y_test1 = train_test_split(
X, y1, stratify=stratify[0], random_state=0)

@jnothman (Member) on Nov 3, 2016

I think stratify=y1 if stratify else None would be more readable (where stratify in (True, False) is iterated)
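Something along these lines (a sketch of the suggested rewrite, adapted from the test shown above, not the final committed code):

for stratify in (True, False):
    X_train1, X_test1, y_train1, y_test1 = train_test_split(
        X, y1, stratify=y1 if stratify else None, random_state=0)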

@raghavrv (Author, Member) on Nov 3, 2016

Done :)

@jnothman (Member) commented Nov 3, 2016

(Maybe we should allow stratify to be an int index into the **args)

@jnothman jnothman merged commit d7c956a into scikit-learn:master Nov 3, 2016

3 checks passed

ci/circleci: Your tests passed on CircleCI!
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed

@raghavrv raghavrv deleted the raghavrv:check_X_y_groups branch Nov 3, 2016

@raghavrv (Author, Member) commented Nov 3, 2016

Thanks for the patient review and merge!

amueller added a commit to amueller/scikit-learn that referenced this pull request Nov 9, 2016

@amueller (Member) commented Nov 14, 2016

needs a whatsnew maybe?

sergeyf added a commit to sergeyf/scikit-learn that referenced this pull request Feb 28, 2017

afiodorov added a commit to unravelin/scikit-learn that referenced this pull request Apr 25, 2017

Sundrique added a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017

paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017

maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017

MLopez-Ibanez pushed a commit to MLopez-Ibanez/scikit-learn that referenced this pull request Feb 9, 2019
