ENH Add Array API compatibility tests for *SearchCV classes#27096
ENH Add Array API compatibility tests for *SearchCV classes#27096ogrisel merged 14 commits intoscikit-learn:mainfrom
*SearchCV classes#27096Conversation
RandomizedSearchCV and GridSearchCV appear to just work with Array API inputs.
970c612 to
6f74abb
Compare
ogrisel
left a comment
There was a problem hiding this comment.
Nice! Can you please add those estimators to the list of estimators in the array api section of the user guide?
This is now the case. Let's see if this can work. |
|
I started some of the failures reported by our test suite after the merge with |
ogrisel
left a comment
There was a problem hiding this comment.
Marking this PR as not accepted because there is still work to do.
| train=None, | ||
| test=None, | ||
| train=train, | ||
| test=test, |
There was a problem hiding this comment.
Not sure why this did no fail in main but based on the docstrings _fit_and_score was never supposed to accept None for its train and test arguments so better pass valid integer arrays in this test instead.
| ) | ||
| search_cv = SearchCV(Estimator(), param_grid, cv=2, **extra_params) | ||
| search_cv = SearchCV( | ||
| Estimator(), param_grid, cv=2, error_score="raise", **extra_params |
There was a problem hiding this comment.
Setting error_score="raise" makes pytest traceback reading much more direct, especially in CI logs.
ogrisel
left a comment
There was a problem hiding this comment.
I think this is +1 on my side. The fact that the CV splitters are currently NumPy only makes the conversion for stratifification a bit ugly.
We probably need a follow-up PR to check that all CV splitters work when splitting array API inputs with train_test_split or other tools that accept array API inputs and cv= parameter but I would rather do that in a dedicated PR and focus this one on what is necessary for *SearchCV themselves.
/cc @OmarManzoor and @betatim.
|
BTW, I tested this PR on a cuda host and on an MPS host and all tests pass. |
betatim
left a comment
There was a problem hiding this comment.
LGTM.
Not sure I should merge it because I opened the PR. But voting to merge it anyway :D
| # we need the following explicit conversion: | ||
| xp, is_array_api = get_namespace(y) | ||
| if is_array_api: | ||
| y = _convert_to_numpy(y, xp) |
There was a problem hiding this comment.
Would it be possible to add a test to cover this?
There was a problem hiding this comment.
I think we had coverage for it via the test I removed in the last commit. But I thought that we tested the same thing via a common test. So I am a bit puzzled why that doesn't happen :-/
There was a problem hiding this comment.
I understand what's going on:
- the common tests run the search cv meta estimator using
LogisticRegressionandRidgeas base-estimator. - the
LogisticRegressiontest is skipped becauseLogisticRegressiondoes not have the array API estimator tag hence the wrapping search cv estimators ain't either; - the stratified k-fold CV splitter is only used for classification problem, hence the
Ridge-based common test does not cover it; - the previous hard-coded test of this PR used
LinearDiscriminantAnalysiswhich is a classifier and supports array API, hence could cover this line.
Since adding array API support to LogisticRegression might take a bit of time, I would be in favor of re-adding the previous non-common test that was removed from this PR.
There was a problem hiding this comment.
Yeah that seems right. Thanks for the explanation! I think the test seems valid for now and not redundant.
There was a problem hiding this comment.
@ogrisel I added the test back. Can you have a look and then we can probably merge.
|
BTW, I tested on CUDA with google colab and https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c and tests are green. |
|
I marked this PR for auto-merge. Thanks all! |
RandomizedSearchCV and GridSearchCV appear to just work with Array API inputs.
This adds a test that makes sure that they will keep working.
For the common tests to pass we need
Ridgeto support the Array API.