Open up check_array and BaseEstimator._validate_data to overriding xp.asarray with an additional callable parameter asarray_fn #25433

Open
fcharras opened this issue Jan 19, 2023 · 4 comments · May be fixed by #25434 or #25617


@fcharras
Contributor

fcharras commented Jan 19, 2023

Describe the workflow you want to enable

Several of us (including @betatim, @ogrisel, @jjerphan and I) have been designing a plugin system that would open scikit-learn estimators to external implementations, in particular implementations with GPU backends; see #22438.

Some of the plugins we are considering materialize the data in memory with an array library that is compatible with the Array API, namely CuPy and dpctl.tensor.

We have found that, internally, those plugins can benefit from directly using scikit-learn's BaseEstimator._validate_data and check_array for the data validation and preparation step.

Describe your proposed solution

To enable this, it would be useful to be able to pass an asarray_fn callable to check_array and _validate_data, which would be called instead of xp.asarray in _asarray_with_order. This would let the plugin convert the input data directly to an array type it supports (e.g. cupy or dpctl.tensor) while still reusing the existing validation code in check_array.

The override can be necessary when the array library's asarray implements a superset of the Array API that the plugin needs, but that check_array does not currently use because it is not part of the Array API (for instance, the order argument is not passed to asarray for array libraries other than NumPy).
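The limitation can be illustrated with a simplified sketch of the dispatch described above (not scikit-learn's actual code): order is forwarded only when the namespace is NumPy, so a library whose asarray also accepts order never receives it. The fake modules and recording_asarray are hypothetical stand-ins.

```python
from types import ModuleType


def current_asarray_with_order(array, dtype=None, order=None, xp=None):
    """Simplified sketch of the behaviour described above."""
    if xp.__name__ == "numpy":
        # NumPy's asarray understands `order`, so it is forwarded.
        return xp.asarray(array, dtype=dtype, order=order)
    # Other namespaces only get the Array API signature: `order` is dropped,
    # even if the library's own asarray would accept it.
    return xp.asarray(array, dtype=dtype)


def recording_asarray(array, dtype=None, order=None):
    # Toy asarray that accepts `order` (a superset of the Array API
    # signature) and records what it received.
    return {"data": list(array), "order": order}


fake_numpy = ModuleType("numpy")
fake_numpy.asarray = recording_asarray
fake_gpu_lib = ModuleType("gpu_lib")
fake_gpu_lib.asarray = recording_asarray

print(current_asarray_with_order([1, 2], order="F", xp=fake_numpy)["order"])    # F
print(current_asarray_with_order([1, 2], order="F", xp=fake_gpu_lib)["order"])  # None
```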

fcharras added the Needs Triage and New Feature labels on Jan 19, 2023
fcharras changed the title from "Open up check_array and BaseEstimator._validate_data to overriding _asarray_with_order with an additional callable parameter asarray_fn" to "Open up check_array and BaseEstimator._validate_data to overriding xp.asarray with an additional callable parameter asarray_fn" on Jan 19, 2023
jjerphan added the RFC label on Jan 19, 2023
fcharras linked a pull request that will close this issue on Jan 19, 2023
@fcharras
Contributor Author

An example is given in #25434.

@ogrisel
Member

ogrisel commented Jan 27, 2023

This was discussed at the triage meeting:

  • people are concerned about kwarg proliferation and the added maintenance burden, but we do not see any other way to avoid duplicating the code of check_array;
  • it was suggested to add this kwarg to a private _check_array variant while keeping the public check_array parameter list unchanged.

Alternatively, we could require explicit activation of this feature through sklearn.experimental.

ogrisel removed the Needs Triage label on Jan 27, 2023
@thomasjpfan
Member

After working on Array API support a little more, I propose adding a namespace parameter, xp, to check_array. For array libraries that do not define __array_namespace__, one can pass a namespace explicitly. The namespace does not have to be fully compatible with the Array API specification; it only needs the functions required by check_array. We would then adjust _asarray_with_order to inspect the signature of xp.asarray and pass order in if it is accepted.

This proposal should enable the Engine/Plugin API to use the Array API code paths for validation in check_array.

@fcharras
Contributor Author

That sounds great. It would cover the use cases I have reported with dpctl.

4 participants