-
-
Notifications
You must be signed in to change notification settings - Fork 26.6k
array API support for cosine_distances #29265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
array API support for cosine_distances #29265
Conversation
ogrisel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is some feedback. Please also test that this works on cuda devices with PyTorch and CuPy using:
https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c
|
Merging main to hopefully get a green CI on this PR. |
ogrisel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @EmilyXiny!
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @EmilyXinyi. A few comments otherwise looks good.
da921c4 to
12d5569
Compare
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @EmilyXinyi
Reference Issues/PRs
towards #26024
What does this implement/fix? Explain your changes.
array API support for
cosine_distancesAny other comments?
.clipis supported in the 2023.12 version but not the one that we are currently using, so I created a function insklearn.utils._array_apifor now.fill_diagonalis not supported and I don't see it being in the plan of being supported in the array-api repo, so I created an alternative._fill_diagonalinsklearn.metrics.pairwisePlease let me know if there are any changes I should make, thanks!