Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Add Array API compatibility to KernelCenterer #27556

Merged
merged 6 commits into from Oct 23, 2023

Conversation

EdAbati
Copy link
Contributor

@EdAbati EdAbati commented Oct 9, 2023

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

It makes the KernelCenterer implementation compatible and tested with the Array API.

@github-actions
Copy link

github-actions bot commented Oct 9, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 2d03e43. Link to the linter CI: here


K_pred_cols = (np.sum(K, axis=1) / self.K_fit_rows_.shape[0])[:, np.newaxis]
K_pred_cols = (xp.sum(K, axis=1) / self.K_fit_rows_.shape[0])[:, None]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried with xp.newaxis initially, but I got an AttributeError both with torch and numpy

AttributeError: module 'array_api_compat.torch' has no attribute 'newaxis'
AttributeError: module 'numpy.array_api' has no attribute 'newaxis'

@betatim do you know why newaxis may not be included?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I don't know why newaxis doesn't exist for them. I would have guessed "bug in array_api_compat" if that was the only place where its missing, but I think ofnumpy.array_api as the "very much exactly what the standard says" implementation. One option is that it is something that only got added in a newer version of the standard

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened data-apis/array-api-compat#64 upstream.

Ok for using None in scikit-learn in the mean time.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Along with #27642 I get the MPS tests to pass for this PR on my local macOS laptop.

I also ran the tests with pytorch and cupy on a CUDA host and all array_api tests pass for the KernelCenterer class.

@betatim betatim merged commit 1dd395c into scikit-learn:main Oct 23, 2023
28 checks passed
@EdAbati EdAbati deleted the KernelCenterer-array-api branch October 23, 2023 16:14
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Oct 31, 2023
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants