
ENH: Add Dask Array API support #28588

Draft: wants to merge 37 commits into base: main
Conversation

@lithomas1 (Contributor) commented Mar 7, 2024

Reference Issues/PRs

#26724

What does this implement/fix? Explain your changes.

Any other comments?

This depends on unmerged/unreleased changes in array-api-compat


github-actions bot commented Mar 7, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: dbb8673.

sklearn/discriminant_analysis.py (outdated)
@@ -113,6 +113,14 @@ def _class_means(X, y):
"""
xp, is_array_api_compliant = get_namespace(X)
classes, y = xp.unique_inverse(y)
# Force lazy array api backends to call compute
if hasattr(classes, "persist"):
Contributor Author

I'm planning on extracting the calls here to a compute function to realize lazy arrays.

Not sure if it will live in scikit-learn, though. I think I'd want it to live in array-api-compat, but I haven't discussed this with the folks there yet.

Based on usage here, the new compute function could have an option to e.g. compute shape only if that's what's needed, or compute the full array.
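
For illustration only, a minimal sketch of what such a compute helper could look like (the name compute and the duck-typing on a .compute() method are assumptions for this discussion, not something this PR or array-api-compat provides):

def compute(*arrays):
    # Materialize lazy arrays (dask arrays expose .compute()); eager arrays pass through.
    results = tuple(a.compute() if hasattr(a, "compute") else a for a in arrays)
    return results[0] if len(results) == 1 else results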

Member

I think this kind of code needs to go somewhere else. Having the "laziness" leak through is a bit silly. I'm not sure where it should be, though.

I don't know how to resolve this issue. In the past I've suggested that accessing something like .shape should just trigger the computation for you, or that everything needs to grow a .compute, even eager libs. Both would allow array consumers like scikit-learn to write code that does not care whether the implementation is lazy or not. However, there doesn't seem to be much support for this within the Array API community. The alternative is having to place code that checks "is this lazy? if yes, trigger compute" in consumers like scikit-learn, which I think is not great.

So this is maybe not a task for you alone to solve, @lithomas1, but we need to find some kind of solution.

@lithomas1 (Contributor Author) commented Mar 24, 2024

Thinking about this some more, would scikit-learn be happy with a shape helper (as opposed to my earlier suggestion of a compute helper) from array-api-compat?

This would handle the laziness of the various array libraries out there and call compute (or whatever the equivalent is) on the array if the shape needs to be materialized, before returning the shape of the array.

(In the event that array-api-compat is not installed, we can just define this to return, e.g. x.shape like scikit-learn already does)

The advantage of this approach would be that scikit-learn doesn't have to think about laziness anymore - that work would be outsourced to array-api-compat.

The only downsides of this approach are that

  • We indiscriminately materialize (for lazy arrays) even if it's strictly not necessary (e.g. we just access the shape to pass it to an array constructor like np.zeros). I don't think we'll lose too much (if any) performance here, though.
  • scikit-learn devs need to remember that accessing .shape is banned, and existing usages have to be migrated.
    • This can be mitigated with a pre-commit hook, to automatically detect .shape accesses.
      I think with something like ruff, one can write a custom rule to automatically rewrite .shape accesses to shape(x)
    • For existing usages, this is something that can be done as part of the Array API migration process, so it shouldn't cause too much churn on its own

How does this sound to you?
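
As a rough illustration of the proposal, a minimal sketch of such a shape helper, assuming dask's convention of reporting unknown dimensions as NaN and its compute_chunk_sizes() method (the helper name and the exact fallback behavior are assumptions, not an agreed API):

import math

def shape(x):
    # Return a fully materialized shape tuple for array x.
    s = x.shape
    # Lazy dask arrays report unknown dimensions as NaN.
    if any(isinstance(dim, float) and math.isnan(dim) for dim in s):
        if hasattr(x, "compute_chunk_sizes"):
            s = x.compute_chunk_sizes().shape
    return tuple(s)

For eager libraries (or when array-api-compat is not installed), this simply returns x.shape, matching what scikit-learn already does.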

Contributor Author

@betatim any thoughts on the above?

Member

Replying here and linking @ogrisel's comment #28588 (comment)

I think having a helper like shape() is the only(?) way forward for now. I'd not add it to array-api-compat but instead add it to scikit-learn in utils/_array_api.py - we already have a few helpers there for things that feel "scikit-learn specific".

More adventurous: I wonder if we can even wrap the dask namespace (and via that its arrays) to make it so that .shape access triggers the computation. That way people who edit scikit-learn's code base don't need to know anything about this issue.
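
To make the "more adventurous" idea a bit more concrete, a very rough sketch of such a wrapper (purely hypothetical; neither this PR nor array-api-compat implements anything like it):

class EagerShapeArray:
    # Proxy that materializes dask chunk sizes the first time .shape is read.
    def __init__(self, array):
        self._array = array

    @property
    def shape(self):
        if hasattr(self._array, "compute_chunk_sizes"):
            self._array = self._array.compute_chunk_sizes()
        return self._array.shape

    def __getattr__(self, name):
        # Delegate everything else to the wrapped array.
        return getattr(self._array, name)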

@ogrisel (Member) commented Apr 10, 2024

Note that a compute helper could also help deal with the boolean-masked assignment problem in r2_score described in more detail in this comment: #28588 (comment)

This is also a lazy evaluation problem, but it is not related to shape values.

Member

I think having a helper like shape() is the only(?) way forward for now. I'd not add it to array-api-compat but instead add it to scikit-learn in utils/_array_api.py - we already have a few helpers there for things that feel "scikit-learn specific".

We could also do both: have a public helper in array-api-compat and a private scikit-learn-specific helper in scikit-learn that does nothing for libraries that are not accessed via array-api-compat (e.g. array-api-strict), as long as the spec does not provide a standard way to deal with this.

Member

More adventurous: I wonder if we can even wrap the dask namespace (and via that its arrays) to make it so that .shape access triggers the computation. That way people who edit scikit-learn's code base don't need to know anything about this issue.

Not sure how feasible this is, and whether it's desirable or not to trigger computation implicitly when using lazy libraries.

# compute right now
# Probably a dask bug
# (the error is also kinda flaky)
y = y.compute()
Contributor Author

To shed more light on what this is, I think this bug happens when scikit-learn calls unique_inverse.

I think something is going wrong somewhere in dask, where the results of intermediate operations are getting corrupted.

Exactly when the error occurs during computation is somewhat flaky, but it happens more often than not without the compute here.

Member

Did you investigate the cause of the corruption? It might be worth reporting a minimal reproducer upstream.
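
If it helps, the kind of skeleton a minimal reproducer could start from (a hedged guess at the call path via array-api-compat's dask namespace; not a confirmed trigger of the corruption, and the sizes/chunking are made up):

import dask.array as da
import array_api_compat.dask.array as xp

y = da.random.randint(0, 3, size=1_000, chunks=100)
classes, y_inv = xp.unique_inverse(y)
print(classes.compute(), y_inv.compute())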

sklearn/utils/_array_api.py (outdated)
@lithomas1 (Contributor Author)

While LDA was a bit hard to port over to dask, PCA worked perfectly out of the box!
(which I think is a pretty big win for the array API)

The other preprocessing/metrics/tools also worked out of the box, I think (at least judging by the tests).

@lithomas1 (Contributor Author) commented Mar 7, 2024

Re: performance

Testing out of core

for dask arrays generated by dask-ml with parameters

n_samples=100_000
n_classes=2
n_informative=5

on a Gitpod machine with 2 cores and 8 GB RAM, I get

47 seconds for 10,000,000 by 20 LDA (1.49 GB)    # 10,000,000 samples
6 min 3.5 s for 50,000,000 by 20 LDA (7.45 GB)   # 50,000,000 samples
14 min 51 s for 100,000,000 by 20 LDA (14.90 GB) # 100,000,000 samples

which I think is pretty decent scaling.

Distributed computation

For

n_samples=20,000,000
n_classes=2
n_informative=5

and chunksize=100,000

I am measuring
45s runtime for 4 workers (2 CPU, 8GB RAM), and
4 min 29s runtime for a single worker

  • note: there was a bit of spilling on this one.

@lithomas1 changed the title from "Add Dask Array API support to LDA/PCA" to "ENH: Add Dask Array API support to LDA/PCA" on Mar 7, 2024
@adrinjalali (Member)

cc @betatim @ogrisel

@lithomas1 (Contributor Author)

FYI, array-api-compat fixes are ongoing here data-apis/array-api-compat#110

@lithomas1 marked this pull request as ready for review on March 24, 2024 21:41
@lithomas1 (Contributor Author)

Since array-api-compat 1.5.1 came out and CI is green here, I'm marking this PR as ready for review.

The only other change I'm planning right now is splitting out the LDA changes, since that requires a patch to dask itself.

The correct way to handle laziness is also something that might be good to think about.
(It might be good to loop in more scikit-learn devs about this).

@lithomas1 changed the title from "ENH: Add Dask Array API support to LDA/PCA" to "ENH: Add Dask Array API support" on Mar 26, 2024
@lithomas1 (Contributor Author)

Here is another pass of feedback, possibly final on my end.

Would it be possible to open issues upstream, either in the dask or array-api-compat repos or in the Array API spec repo, to link to from the comments of the skipped tests?

Yep, I'm tracking failures in #26724 (comment)

Failures are listed as checkboxes (with accompanying upstream issues if necessary).

Thanks for the reviews.
I've updated the PR following the feedback.

doc/modules/array_api.rst (outdated)
@ogrisel (Member) left a comment

Assuming the CI is green, LGTM for a first iteration.

What are your thoughts, @betatim? Are you comfortable with merging partial dask support with skipped tests? Or do you prefer to implement more complete support first?

@@ -47,6 +47,12 @@ See :ref:`array_api` for more details.

**Classes:**

**Libraries:**

- ``dask.array`` is now experimentally supported as an array API backend.
Member

Suggested change
- ``dask.array`` is now experimentally supported as an array API backend.
- ``dask.array`` is now experimentally supported as an Array API backend.

Member

Looks like Olivier commented on a few of these before as well. I won't comment on all that I find, but it would be great to have "Array API" everywhere.


not to be pedantic, but I would recommend not capitalising this, as per data-apis/array-api#778!

Member

Thanks for the feedback, I was not aware of this. I guess we have a lot of doc/comments to update...

Let's fix that in a dedicated PR.

@betatim (Member) commented Apr 10, 2024

Are you comfortable with merging partial dask support with skipped tests? Or do you prefer to implement a more complete support first?

Unsure. I started looking at the diff just now. I think if some estimators work and some don't, that is one thing. But if only some of an estimator works, that is a bit weird. From a user's perspective it would be weird, I think, because you can't really do much if only part of an estimator works.

Comment on lines +333 to +334
"Estimator/method does not work because of dask array API compliance"
" issues"
Member

Suggested change
"Estimator/method does not work because of dask array API compliance"
" issues"
"Estimator/method does not work because of missing dask Array API compliance."

WDYT?

@@ -1736,6 +1737,11 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
def check_array_api_metric(
metric, array_namespace, device, dtype_name, y_true_np, y_pred_np, sample_weight
):
if _array_api_skips.get(metric.__name__, {}).get(array_namespace) == "all":
pytest.skip(
f"{array_namespace} is not Array API compliant for {metric.__name__}"
Member

Suggested change
f"{array_namespace} is not Array API compliant for {metric.__name__}"
f"Skipping {metric.__name__} because of missing Array API compliance in {array_namespace}"

It feels like the current text is backwards. What do you think of this change? Trying to make it clear that "we are skipping testing for X because something in {array_namespace}'s Array API support is missing" - which is what I think we are doing.

@@ -74,6 +93,13 @@ def _check_array_api_dispatch(array_api_dispatch):
"array_api_compat is required to dispatch arrays using the API"
Member

Suggested change
"array_api_compat is required to dispatch arrays using the API"
"array-api-compat is required to dispatch arrays using the API"

Trying to be consistent with the spelling below. The package on PyPI is called "array-api-compat", so I think it makes sense to use that.

Comment on lines -951 to +967
if method is None:
continue
if method is None or method_name in methods_to_skip:
raise SkipTest(
f"{array_namespace} is not Array API compliant for method"
f" {method_name} of {name}"
)
Member

Why the change for `method is None`? This is the case when an estimator doesn't have a particular method, in which case we aren't skipping a test (and certainly not because of missing Array API support :D)

@@ -1019,6 +1034,7 @@ def check_array_api_input_and_values(
array_namespace,
device=None,
dtype_name="float64",
skip_methods={},
@betatim (Member) commented Apr 10, 2024

I think we should use skip_methods=None + if skip_methods is None: skip_methods = {} in the function body. A mutable value as the default value for a function argument is a trap waiting for someone to fall into. Right now I think this doesn't cause a bug, but someone from the future might change something without realising that we set this trap for them.

Details
def foo(x, a=[]):
    # the default list is created once and shared across all calls
    a.append(x)
    print(a)

foo(1)  # prints [1]
foo(1)  # prints [1, 1]
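
For completeness, a small sketch of the suggested skip_methods=None pattern (function name simplified for illustration; not the actual test helper):

def check_something(name, skip_methods=None):
    # Create a fresh dict per call instead of sharing one default instance.
    if skip_methods is None:
        skip_methods = {}
    skip_methods.setdefault(name, [])
    return skip_methods

print(check_something("transform"))  # {'transform': []}
print(check_something("predict"))    # {'predict': []}  (no state leaks between calls)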

@betatim (Member) commented Apr 10, 2024

Can we use array-api-compat to add slogdet to dask arrays? Then we'd avoid the problem that an estimator only partially works.

For the two other skips (LinearDiscriminantAnalysis and r2_score), what is the error/exception a user will see when they try to use them with a dask array? Should we have a dedicated exception message for this case or is "random error", like you get when you use an estimator that doesn't yet have Array API support, good enough?

@ogrisel (Member) commented Apr 10, 2024

Should we have a dedicated exception message for this case or is "random error", like you get when you use an estimator that doesn't yet have Array API support, good enough?

It is going to be hard (if not impossible) to do that in a library-agnostic way. Note that we already get a low-level error message with dask arrays on main:

>>> import sklearn
>>> sklearn.set_config(array_api_dispatch=True)
>>> import dask.array as da
>>> X = da.random.normal(size=(1000, 10))
>>> y = da.random.randint(0, 2, size=1000)
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> LinearDiscriminantAnalysis().fit(X, y)
Traceback (most recent call last):
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/dask/backends.py:140 in wrapper
    return func(*args, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/dask/array/wrap.py:63 in wrap_func_shape_as_first_arg
    parsed = _parse_wrap_args(func, args, kwargs, shape)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/dask/array/wrap.py:33 in _parse_wrap_args
    chunks = normalize_chunks(chunks, shape, dtype=dtype)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/dask/array/core.py:3095 in normalize_chunks
    chunks = auto_chunks(chunks, shape, limit, dtype, previous_chunks)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/dask/array/core.py:3210 in auto_chunks
    raise ValueError(
ValueError: Can not perform automatic rechunking with unknown (nan) chunk sizes.

A possible solution: https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks
Summary: to compute chunks sizes, use

   x.compute_chunk_sizes()  # for Dask Array `x`
   ddf.to_dask_array(lengths=True)  # for Dask DataFrame `ddf`

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[11], line 1
    LinearDiscriminantAnalysis().fit(X, y)
  File ~/code/scikit-learn/sklearn/base.py:1474 in wrapper
    return fit_method(estimator, *args, **kwargs)
  File ~/code/scikit-learn/sklearn/discriminant_analysis.py:638 in fit
    self._solve_svd(X, y)
  File ~/code/scikit-learn/sklearn/discriminant_analysis.py:510 in _solve_svd
    self.means_ = _class_means(X, y)
  File ~/code/scikit-learn/sklearn/discriminant_analysis.py:116 in _class_means
    means = xp.zeros((classes.shape[0], X.shape[1]), device=device(X), dtype=X.dtype)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:134 in zeros
    return xp.zeros(shape, dtype=dtype, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/dask/backends.py:142 in wrapper
    raise type(e)(
ValueError: An error occurred while calling the wrap_func_shape_as_first_arg method registered to the numpy backend.
Original Message: Can not perform automatic rechunking with unknown (nan) chunk sizes.

A possible solution: https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks
Summary: to compute chunks sizes, use

   x.compute_chunk_sizes()  # for Dask Array `x`
   ddf.to_dask_array(lengths=True)  # for Dask DataFrame `ddf`

So from this point of view, this PR is not making things worse than they currently are. At least we document that dask support is not fully functional yet.

@lithomas1 (Contributor Author)

Can we use array-api-compat to add slogdet to dask arrays? Then we'd avoid the problem that an estimator semi works.

This is not possible yet.

Dask doesn't expose a way to compute a determinant yet.
(and there doesn't seem to be an easy workaround, since dask's LU/QR decompositions don't provide enough information to do it. I've opened an issue on the dask side, dask/dask#11042, but there doesn't seem to be a response yet)

I agree that it's not ideal to have dask not be able to use score, but this is something that I think is reasonable to have users work around for now, e.g. by using _estimator_with_converted_arrays to convert the estimator from dask to numpy arrays
(similar to how we transfer arrays from GPU to CPU with cupy).

I'm not sure how important score is, though. There seem to be some usages of it, but curiously there seem to be no examples in scikit-learn itself that use score.

https://github.com/search?q=repo%3Ascikit-learn%2Fscikit-learn+pca.score&type=code
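
For illustration, the workaround mentioned above might look roughly like this (a sketch that assumes the dask support from this PR for the fit call, and that the existing private helper _estimator_with_converted_arrays recognizes dask array attributes; both are assumptions):

import dask.array as da
import sklearn
from sklearn.decomposition import PCA
from sklearn.utils._array_api import _estimator_with_converted_arrays

sklearn.set_config(array_api_dispatch=True)
X = da.random.normal(size=(1000, 10), chunks=(100, 10))
est = PCA(n_components=2).fit(X)  # fitted on dask arrays
# Convert fitted attributes from dask to numpy, then score on numpy data.
est_np = _estimator_with_converted_arrays(
    est, lambda a: a.compute() if hasattr(a, "compute") else a
)
print(est_np.score(X.compute()))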

@betatim (Member) commented Apr 22, 2024

My feeling overall is that adding dask support leaves us with something that is only half finished for the (very few) existing estimators. My first reaction to that is to wait a bit until Array API support in dask has progressed a bit more. My assumption is that there is no need to be in a hurry here.

@ogrisel (Member) commented Apr 23, 2024

I'm not sure how important score is though. There seem to be some usages of it, but curiously there seem to be no examples in scikit-learn itself on using score.

The score method is implicitly used by tools such as cross_val_score or GridSearchCV. However, it is true that it is very rarely used for PCA alone in practice. It's mostly used for supervised learning pipelines.
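
(For context, a minimal numpy-only example of score being used implicitly through cross_val_score; this is standard scikit-learn usage, not specific to this PR:)

from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=200, random_state=0)
# With no explicit scoring argument, cross_val_score calls the estimator's score method.
print(cross_val_score(LinearDiscriminantAnalysis(), X, y, cv=3))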

I agree that it's not ideal to have dask not be able to use score, but this is something that I think is reasonable to have users work around for now, e.g. by using _estimator_with_converted_arrays, to convert the estimator from dask to numpy arrays.
(similar to how we transfer arrays from GPU to CPU on cupy)

It might also be the case that the score method of the estimator itself can convert arrays to numpy if the namespace does not provide the necessary xp.linalg.slogdet function. For truncated PCA, we can expect the call to the fit method (and maybe get_precision) to be slow and to deserve running on an accelerated namespace, while the final xp.linalg.slogdet call should not be a performance-critical operation (and the result is a scalar).
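
A rough sketch of what such a fallback could look like (helper name and structure are assumptions, not code from this PR):

import numpy as np

def slogdet_with_fallback(xp, A):
    # Use the namespace's slogdet when available; otherwise convert to numpy
    # (materializing lazy arrays first) and compute the scalar result there.
    if hasattr(xp, "linalg") and hasattr(xp.linalg, "slogdet"):
        return xp.linalg.slogdet(A)
    A_np = np.asarray(A.compute() if hasattr(A, "compute") else A)
    return np.linalg.slogdet(A_np)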

Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>
@@ -288,11 +288,14 @@ def _solve_cholesky_kernel(K, y, alpha, sample_weight=None, copy=False):
def _solve_svd(X, y, alpha, xp=None):
xp, _ = get_namespace(X, xp=xp)
U, s, Vt = xp.linalg.svd(X, full_matrices=False)
idx = s > 1e-15 # same default value as scipy.linalg.pinv
s_nnz = s[idx][:, None]
idx = s > 1e-15[:, None] # same default value as scipy.linalg.pinv
Member

`>` does not have precedence over `[]`, right? Trying this pattern locally yields:

TypeError: 'float' object is not subscriptable

I think you meant the following instead:

Suggested change
idx = s > 1e-15[:, None] # same default value as scipy.linalg.pinv
# scipy.linalg.pinv also thresholds at 1e-15 by default.
idx = (s > 1e-15)[:, None]

Member

But then the following `s = s[:, None]` seems redundant, and similarly for `idx[:, None]` in the call to `where`.

Also, we should probably rename `idx` to something more correct and explicit, such as `strictly_positive_mask`.
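
To make both suggestions concrete, a toy sketch combining them (variable names and the toy data are illustrative; this is not the PR's actual _solve_svd code):

import array_api_compat.numpy as xp  # numpy namespace standing in for any array API backend

s = xp.asarray([3.0, 1.0, 1e-20])  # toy singular values
alpha = xp.asarray([0.1, 1.0])     # toy ridge penalties
# scipy.linalg.pinv also thresholds at 1e-15 by default.
strictly_positive_mask = (s > 1e-15)[:, None]
s_col = s[:, None]
d = xp.where(strictly_positive_mask, s_col / (s_col**2 + alpha), xp.zeros_like(s_col))
print(d)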

Contributor Author

Thanks! I think I had this working locally and passing tests, but I must have messed something up on the merge.

I'm planning on circling back to this towards the end of summer, so feel free to take over this if you're interested in the meantime.

Not holding my breath, but also hoping that dask.array support improves in the meantime as well.

(The recent Array API updates suggest that sort is the most pressing thing that's missing in dask. Linalg-wise, the eig family of methods is probably the next biggest missing feature in dask; we just haven't seen it come up yet since not a lot of estimators have been ported yet.)

@ogrisel marked this pull request as draft on May 24, 2024 13:59
@ogrisel (Member) commented May 24, 2024

Let's convert to draft for the time being then.

Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues: none yet.

6 participants