New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH Add Array API compatibility to KernelCenterer
#27556
Conversation
|
||
K_pred_cols = (np.sum(K, axis=1) / self.K_fit_rows_.shape[0])[:, np.newaxis] | ||
K_pred_cols = (xp.sum(K, axis=1) / self.K_fit_rows_.shape[0])[:, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be replaced with xp.newaxis for readability
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually tried with xp.newaxis
initially, but I got an AttributeError
both with torch
and numpy
AttributeError: module 'array_api_compat.torch' has no attribute 'newaxis'
AttributeError: module 'numpy.array_api' has no attribute 'newaxis'
@betatim do you know why newaxis
may not be included?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately I don't know why newaxis
doesn't exist for them. I would have guessed "bug in array_api_compat" if that was the only place where its missing, but I think ofnumpy.array_api
as the "very much exactly what the standard says" implementation. One option is that it is something that only got added in a newer version of the standard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I opened data-apis/array-api-compat#64 upstream.
Ok for using None
in scikit-learn in the mean time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Along with #27642 I get the MPS tests to pass for this PR on my local macOS laptop.
I also ran the tests with pytorch and cupy on a CUDA host and all array_api tests pass for the KernelCenterer
class.
Reference Issues/PRs
Towards #26024
What does this implement/fix? Explain your changes.
It makes the
KernelCenterer
implementation compatible and tested with the Array API.