-
Notifications
You must be signed in to change notification settings - Fork 183
[enhancement] simplify array_api enabling tags via wrapper #2566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[enhancement] simplify array_api enabling tags via wrapper #2566
Conversation
Codecov ReportAttention: Patch coverage is
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
/intelci: run |
After the latest merges, is there something blocking this PR from applying this function to the current classes that can work with array API? |
You are right on point with that comment. Unfortunately none of the current estimators are fully array api compliant (realized I didn't include it while writing the first part of the docs), but a deluge of estimators should be ready for review using it in <1 month. |
/intelci: run |
@icfaust Then I think it'd be better to hold off this PR until there are estimators ready to apply and test in the same PR. |
Maybe a stupid question, but wouldn't it require checking the version of sklearn (1.3, 1.6) further in the code of each estimator to use the proper API? |
That is a good question, this will take care of all issues related to array API support at class definition time (import time), meaning those additional checks throughout the codebase are all taken care of here: https://github.com/uxlfoundation/scikit-learn-intelex/blob/main/sklearnex/_device_offload.py#L118 in combination with https://github.com/uxlfoundation/scikit-learn-intelex/blob/main/sklearnex/_utils.py#L31-L43 . These will hide all of the sklearn version oddities related to it from the developer and user and are all done at import time, and makes all of the interfaces look like latest release sklearn with respect to tag interfaces. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Adds a decorator to automatically attach the right tags for array API support to oneDALEstimator subclasses, reducing boilerplate across estimators.
- Imports
oneDALEstimator
andTags
to build the wrapper - Adds
enable_array_api
function that sets__sklearn_tags__
for sklearn ≥1.6 or_more_tags
for sklearn ≥1.3 - Returns the modified class so users can apply
@enable_array_api
on estimators
Comments suppressed due to low confidence (2)
sklearnex/utils/_array_api.py:90
- No unit tests were added to verify that
enable_array_api
correctly attaches__sklearn_tags__
and_more_tags
under the intended sklearn versions. Consider adding tests for both code paths.
def enable_array_api(original_class: type[oneDALEstimator]) -> type[oneDALEstimator]:
sklearnex/utils/_array_api.py:117
- The return type annotation
dict[bool]
is invalid. It should specify key and value types, e.g.,dict[str, bool]
.
def _more_tags(self) -> dict[bool]:
Description
Add
enable_array_api
wrapper for classes. This will set the necessary method (__sklearn_tags__
for sklearn>=1.6 or_more_tags
for sklearn > 1.3) to signify to the sklearnex infrastructure thatarray_api_dispatch
would work with onedal for that estimator when enabled. Otherwise this code would have to be manually added to each estimator, this will simplify development.No performance benchmarks necessary as it is an import time cost.
PR should start as a draft, then move to ready for review state after CI is passed and all applicable checkboxes are closed.
This approach ensures that reviewers don't spend extra time asking for regular requirements.
You can remove a checkbox as not applicable only if it doesn't relate to this PR in any way.
For example, PR with docs update doesn't require checkboxes for performance while PR with any change in actual code should have checkboxes and justify how this code change is expected to affect performance (or justification should be self-evident).
Checklist to comply with before moving PR from draft:
PR completeness and readability
Testing
Performance