Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for train_test_split with Array API input #26855

Merged
merged 7 commits into from Aug 9, 2023

Conversation

betatim
Copy link
Member

@betatim betatim commented Jul 18, 2023

Reference Issues/PRs

(need to find one)

What does this implement/fix? Explain your changes.

This mostly adds some tests that use train_test_split with Array API input and compare to using a pure Numpy array as input.

Any other comments?

First attempt of seeing what happens when you feed cupy/pytorch/array api arrays to train_test_split. Need to explore more of the different parameters to see if they all "just work".

  • add a changelog entry

@github-actions
Copy link

github-actions bot commented Jul 18, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 998eb93. Link to the linter CI: here

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great that it works out of the box.

sklearn/model_selection/tests/test_split.py Show resolved Hide resolved
@betatim betatim marked this pull request as ready for review July 19, 2023 16:13
Comment on lines +189 to +190
def __eq__(self, other):
return self._namespace == other._namespace
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this? It is convenient in the test to be able to compare (wrapped) namespaces for equivalence.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think i prefer explicit namespace assertions 8n tests. We could have a helper to assert same namespace in tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean with explicit? Getting a string representation that we can compare to the array_namespace passed in to the test?

In the test itself I use get_namespace(input)[0] == get_namespace(output)[0] to check that input and output are in the same namespace. This works when the namespace is one from the array compat library, but not for the few namespaces that we wrap in this wrapper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am okay with overriding __eq__ like this.

@betatim
Copy link
Member Author

betatim commented Jul 20, 2023

@thomasjpfan and @ogrisel - if you want to look at a PR that mostly adds new tests, this is one :D

sklearn/model_selection/tests/test_split.py Outdated Show resolved Hide resolved
Comment on lines +189 to +190
def __eq__(self, other):
return self._namespace == other._namespace
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think i prefer explicit namespace assertions 8n tests. We could have a helper to assert same namespace in tests.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

sklearn/model_selection/tests/test_split.py Show resolved Hide resolved
Comment on lines +189 to +190
def __eq__(self, other):
return self._namespace == other._namespace
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am okay with overriding __eq__ like this.

@betatim
Copy link
Member Author

betatim commented Aug 4, 2023

Should we list this kind of thing (functions, not estimators) in the "estimators with support" section of doc/modules/array_api.rst? New section?

@thomasjpfan
Copy link
Member

Should we list this kind of thing (functions, not estimators) in the "estimators with support" section of doc/modules/array_api.rst? New section?

I like a new section. I think it's good to keep track of all the Array API supported estimators & functions in array_api.rst.

@betatim
Copy link
Member Author

betatim commented Aug 8, 2023

What do you think of the current patch? I added subsections, one called Estimators and one called Tools.

@thomasjpfan thomasjpfan enabled auto-merge (squash) August 9, 2023 17:36
@thomasjpfan
Copy link
Member

What do you think of the current patch? I added subsections, one called Estimators and one called Tools.

That is okay with me.

@thomasjpfan thomasjpfan merged commit 1b0a51b into scikit-learn:main Aug 9, 2023
25 checks passed
@betatim betatim deleted the array_api_train_test_split branch August 10, 2023 13:04
TamaraAtanasoska pushed a commit to TamaraAtanasoska/scikit-learn that referenced this pull request Aug 21, 2023
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants