Fix array api integration for additive_chi2_kernel with torch mps#29256
OmarManzoor wants to merge 1 commit into scikit-learn:main
Conversation
Oh, I overlooked that it has to be on the same device. Maybe the common tests should also assert this? That would spot the error at least in the GPU workflows.
I think this might be specifically an issue with mps because it seems to work fine with cuda. |
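For context, the failure mode can be sketched in plain Python (function and helper names below are hypothetical, not the actual scikit-learn implementation): an intermediate array created without an explicit device lands on the default device (CPU), while X and Y live on "mps", so the backend refuses to mix them. A device-aware version of the additive chi-squared kernel, k(x, y) = -sum_i (x_i - y_i)^2 / (x_i + y_i), might look like:

```python
import numpy as np


def _device(x, default="cpu"):
    # Stand-in (assumption) for sklearn.utils._array_api.device:
    # return the device an array lives on, "cpu" for plain NumPy arrays.
    return getattr(x, "device", default)


def additive_chi2_sketch(X, Y, xp=np):
    """Hedged sketch of a device-aware additive chi-squared kernel."""
    X = xp.asarray(X)
    Y = xp.asarray(Y)
    dev = _device(X)  # with torch on "mps", intermediates must use this
    # On a GPU namespace, creation calls would pass ``device=dev``
    # (e.g. ``xp.zeros(..., device=dev)``); NumPy accepts no such
    # argument pre-2.0, so it is omitted here to keep the sketch runnable.
    result = xp.zeros((X.shape[0], Y.shape[0]), dtype=X.dtype)
    for i in range(X.shape[0]):
        for j in range(Y.shape[0]):
            num = (X[i] - Y[j]) ** 2
            denom = X[i] + Y[j]
            nz = denom != 0  # skip zero denominators, as the kernel does
            result[i, j] = -xp.sum(num[nz] / denom[nz])
    return result
```

The loop form is for clarity only; the real implementation is vectorized.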
EdAbati
left a comment
I confirm that this fix is needed for the mps device 😕
    from ..utils._array_api import (
        _find_matching_floating_dtype,
        _is_numpy_namespace,
        device,
I think we could also use get_namespace_and_device and substitute it for xp, _ = get_namespace(X, Y) below.
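The suggestion above, sketched with a minimal stand-in (assumption: get_namespace_and_device is a private scikit-learn helper whose exact return signature may differ across versions; the real one also reports array API compliance):

```python
import numpy as np


def get_namespace_and_device(*arrays):
    # Minimal stand-in (assumption) for the private helper
    # sklearn.utils._array_api.get_namespace_and_device: resolve the
    # array namespace and the common device of the inputs in one call.
    # Here the namespace is fixed to NumPy for the sake of the sketch.
    devices = {getattr(a, "device", "cpu") for a in arrays}
    if len(devices) != 1:
        raise ValueError(f"inputs live on multiple devices: {devices}")
    return np, devices.pop()


X = np.asarray([[1.0, 2.0]])
Y = np.asarray([[3.0, 4.0]])
# One call replaces ``xp, _ = get_namespace(X, Y)`` followed by a
# separate ``device(X)`` lookup.
xp, dev = get_namespace_and_device(X, Y)
```

Fetching both at once also guarantees the device used for intermediates is the one the inputs actually share.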
FYI I have updated to the latest
Nice that they did this. It was verbose to have to specify the device even for scalar arrays. I am -0 for merging this PR. Since the array API is very new, one can expect people who want to try it to use the latest stable versions of the libraries, and I would rather not make our codebase more verbose than necessary. What do other people think? We could make that requirement explicit in the
I think it makes sense if it is fixed. No need for the change. @ogrisel should we close this PR?
Let's wait a bit to let others have time to express their opinion.
Let's close; we can reopen if others disagree.
Note: we would have to check with other libraries. For instance I think jax support in
Reference Issues/PRs
Follow up of #29144