ENH safe_sparse_dot work on 3D+ arrays #14538
…atmul instead of dot when applicable, add tests
@@ -117,29 +117,40 @@ def density(w, **kwargs):
def safe_sparse_dot(a, b, dense_output=False):
"""Dot product that handle the sparse matrix case correctly
Uses BLAS GEMM as replacement for numpy.dot where possible
to avoid unnecessary copies.
Removed that: it does not do that.
if dense_output and hasattr(ret, "toarray"):
ret = ret.toarray()
return ret
if a.ndim > 2 or b.ndim > 2:
Does this really work for ND arrays? I would restrict it to 3D, which is what we are using, and raise a NotImplementedError otherwise. We are not testing the 4D+ case anyway.
Yes it does :)
I updated the tests to ensure that. I don't see any reason to raise an error since it's implemented and works.
where do we need 3d?
At least in coordinate descent (see diff) between 3D dense and sparse.
ret = ret.toarray()
return ret
if a.ndim > 2 or b.ndim > 2:
if sparse.issparse(a):
So implicitly here a is 2D and b is 3D+; might be worth adding a comment.
b_2d = b_.reshape((b.shape[-2], -1))
ret = a @ b_2d
ret = ret.reshape(a.shape[0], *b_.shape[1:])
elif sparse.issparse(b):
Same
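The reshaping discussed in this thread can be checked against `np.dot` on a densified copy. What follows is a minimal sketch, not the PR's code: the helper name `sparse_dot_nd` is hypothetical, `np.moveaxis(b, -2, 0)` is an assumption about how `b_` is built (its definition is not part of the quoted diff), and the `sparse.issparse(b)` branch is an assumed mirror image of the first one.

```python
import numpy as np
from scipy import sparse


def sparse_dot_nd(a, b):
    """Hypothetical helper: np.dot semantics when one operand is a 2D sparse
    matrix and the other is a dense array with 3 or more dimensions."""
    if sparse.issparse(a):
        # a is 2D sparse (i, j); b is dense (..., j, n).
        # Assumption: bring the contracted axis of b to the front.
        b_ = np.moveaxis(b, -2, 0)              # (j, ..., n)
        b_2d = b_.reshape((b.shape[-2], -1))    # (j, product of the rest)
        ret = a @ b_2d                          # dense, (i, product of the rest)
        return ret.reshape(a.shape[0], *b_.shape[1:])
    # Assumed mirror case: a is dense (..., m); b is 2D sparse (m, j).
    a_2d = a.reshape(-1, a.shape[-1])           # (product of leading axes, m)
    ret = a_2d @ b                              # dense, (product of leading axes, j)
    return ret.reshape(*a.shape[:-1], b.shape[1])


rng = np.random.RandomState(0)
dense_4d = rng.rand(2, 3, 5, 6)
sp_left = sparse.random(4, 5, density=0.5, format="csr", random_state=0)
sp_right = sparse.random(6, 7, density=0.5, format="csr", random_state=1)

left = sparse_dot_nd(sp_left, dense_4d)     # shape (4, 2, 3, 6)
right = sparse_dot_nd(dense_4d, sp_right)   # shape (2, 3, 5, 7)
```

Both results agree with `np.dot` applied to the dense equivalents, which is the kind of check a 4D test would rely on.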
I wouldn't have been sure about supporting >=3d, but I'm happy with this. I'd like to change the title to not say FIX... it's a DOC fix because the documented behaviour is unreasonable ;)
Add what's new?
@rth I think I addressed your comments
doc/whats_new/v0.22.rst
Outdated
sparse matrix.
|Fix| The documented behavior of the ``dense_output`` parameter of
:func:`utils.safe_sparse_dot` was not the intended behavior and now matches
the actual behavior of the function.
So the second part is a documentation fix? Not sure it needs a what's new entry?
I was not sure about that. It's true that the behavior did not change, only its documentation. I removed it.
The `safe_sparse_dot` docstring was not accurate regarding the `dense_output` parameter. If only one of `a` and `b` is sparse, it would always return a dense array. I don't think it was an intended behavior anyway since I saw several which would break if the output was sparse. Moreover, it's very unlikely that dense @ sparse has many zeros. So I changed the docstring to match the behavior.
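The behavior described above can be reproduced with SciPy alone. This is a minimal sketch, assuming SciPy's `csr_matrix`; the helper `dot_maybe_dense` is hypothetical and only mirrors the `dense_output` check quoted in the diff, it is not the scikit-learn function.

```python
import numpy as np
from scipy import sparse


def dot_maybe_dense(a, b, dense_output=False):
    """Hypothetical stand-in that reuses the check quoted in the diff."""
    ret = a @ b
    # Only sparse results expose .toarray(); dense results pass through.
    if dense_output and hasattr(ret, "toarray"):
        ret = ret.toarray()
    return ret


sp = sparse.csr_matrix(np.eye(3))
dense = np.ones((3, 2))

mixed = dot_maybe_dense(sp, dense)                        # dense without asking
both = dot_maybe_dense(sp, sp)                            # stays sparse
both_dense = dot_maybe_dense(sp, sp, dense_output=True)   # densified on request
```

In this sketch `dense_output` only changes the result when both operands are sparse, which is what the corrected docstring has to describe.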
It did not handle the case when one of the arrays is 3D+ and the other is sparse.
When arrays are 1D or 2D, matmul is preferred over dot, and is a little bit faster. I replaced dot with @ when applicable.
The function was untested.
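One reading of "when applicable" is that `@` and `np.dot` only coincide for 1D and 2D operands. The NumPy-only sketch below (array names are illustrative, and it makes no claim about speed) shows the two agreeing in low dimensions and diverging in shape for 3D inputs.

```python
import numpy as np

rng = np.random.RandomState(0)

# 1D and 2D operands: both operators give the same result.
x, y = rng.rand(4), rng.rand(4)
m, n = rng.rand(3, 4), rng.rand(4, 5)
same_1d = np.isclose(np.dot(x, y), x @ y)
same_2d = np.allclose(np.dot(m, n), m @ n)

# 3D operands: np.dot pairs every leading index of p with every leading
# index of q, while @ broadcasts over the shared batch axis.
p, q = rng.rand(2, 3, 4), rng.rand(2, 4, 5)
dot_shape = np.dot(p, q).shape     # (2, 3, 2, 5)
matmul_shape = (p @ q).shape       # (2, 3, 5)
```

This is why a dense-dense 3D+ product cannot simply switch from `np.dot` to `@` without changing the output.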