ENH add support for Array API to mean_pinball_loss and explained_variance_score#29978
Conversation
glemaitre
left a comment
There was a problem hiding this comment.
So I think that the functions relying on the percentile or _weighted_percentile function cannot be implemented for the moment.
394a527 to
cf41a3e
Compare
cf41a3e to
d2acd79
Compare
mean_pinball_loss and explained_variance_score
TL;DR: I've reverted the metrics. Thank you for catching that! I have resolved the dependence of I'll also add a note in the Array API meta issue to inform others to skip |
…ck_reg_targets_with_floating_dtype`.
|
Thanks for the updates @OmarManzoor! It's ready for a second review! |
OmarManzoor
left a comment
There was a problem hiding this comment.
@virchan Two of the CUDA tests are failing. I think we need to handle using the appropriate device. Could you kindly have a look?
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
|
Thank you for the follow-up, @OmarManzoor! I have updated the |
|
@virchan Were you able to test with the MPS device locally? |
OmarManzoor
left a comment
There was a problem hiding this comment.
LGTM. Thanks @virchan
I checked with MPS as well and it looks okay 🟢
|
Thank you everyone for your time! |
Reference Issues/PRs
Towards #26024.
What does this implement/fix? Explain your changes.
This PR adds Array API support for the following regression metrics:
(dependent onmedian_absolute_error_weighted_percentile)explained_variance_score(dependent ond2_pinball_scorepercentileand_weighted_percentile)(dependent ond2_absolute_error_scorepercentileand_weighted_percentile)mean_pinball_lossAny other comments?
cc: @ogrisel