Add N-ary broadcasting operations. #98
Conversation
assert_eq(sparse.elemwise(func, xs, ys, zs), func(x, y, z))
mrocklin
Feb 10, 2018
Collaborator
There are some extra checks in the removed tests that we may want to maintain, for example that the result of elemwise is a `COO`, and that its non-zeros are as expected.
mrocklin
Feb 10, 2018
Collaborator
We might also consider having tests for some of the following:
- N-ary broadcasting where the arguments have different shapes
- N-ary broadcasting including arguments that are scalars and zero-dimensional arrays
__doc__ = func.__doc__

return Partial()
mrocklin
Feb 10, 2018
Collaborator
Thoughts on replacing this with just a `functools.partial` on top of a normal function?
This is our solution for dask:

def partial_by_order(*args, **kwargs):
    """
    >>> from operator import add
    >>> partial_by_order(5, function=add, other=[(1, 10)])
    15
    """
    function = kwargs.pop('function')
    other = kwargs.pop('other')
    args2 = list(args)
    for i, arg in other:
        args2.insert(i, arg)
    return function(*args2, **kwargs)
hameerabbasi
Feb 10, 2018
Author
Collaborator
I could have, but our situation is slightly unique:
- We're using `str(func)` for exceptions. `functools.wraps` doesn't work on that for all callables (e.g. `ufunc`s), and breaks a few docstrings. This leads to illegible names in exceptions like `_posarg_partial.<locals>.wrapper` (and the same for debugging).
- We're replacing a number of arguments in different positions.
I guess I could turn it into a class rather than a decorator-style function.
hameerabbasi
Feb 11, 2018
Author
Collaborator
I turned it into a callable class.
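For readers following along, here is a minimal sketch of what such a callable class could look like. It is not the code merged in this PR: the class name `PositionalPartial`, its attributes, and the `affine` example function are all hypothetical. It only illustrates the two requirements named above, namely baking several values into arbitrary positions and keeping `str()` readable for error messages.

```python
class PositionalPartial:
    """Hypothetical helper: bake fixed values into given argument positions."""

    def __init__(self, func, pos, posargs):
        self.func = func
        # Sort by position so the insertions below run left to right.
        self.fixed = sorted(zip(pos, posargs), key=lambda pair: pair[0])
        self.__doc__ = func.__doc__

    def __call__(self, *args, **kwargs):
        args = list(args)
        for i, value in self.fixed:
            args.insert(i, value)
        return self.func(*args, **kwargs)

    def __str__(self):
        # Delegate so messages name the wrapped function, not the wrapper.
        return str(self.func)

    __repr__ = __str__


def affine(x, scale, offset):
    return x * scale + offset


# Bake scale=10 and offset=1 into positions 1 and 2; only x stays free.
f = PositionalPartial(affine, pos=[1, 2], posargs=[10, 1])
print(f(5))  # 5 * 10 + 1
```

Because `__str__` forwards to the wrapped callable, an error message built with `str(f)` shows the original function rather than a nested wrapper name.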
@@ -2426,80 +2468,39 @@ def _elemwise_unary(func, self, *args, **kwargs):
sorted=self.sorted)

def _get_matching_coords(coords1, coords2, shape1, shape2):
def _get_nary_matching_coords(coords, params, shape):
mrocklin
Feb 10, 2018
Collaborator
Maybe just call this `_get_matching_coords` and drop the nary. Presumably there will be no need to distinguish any longer.
hameerabbasi
Feb 11, 2018
Author
Collaborator
Done!
matching_coords : np.ndarray
The coordinates of the output array for which both inputs will be nonzero.
numpy.ndarray
The broadcasted coordinates.
mrocklin
Feb 10, 2018
Collaborator
Style nit, there is no need to place a period at the end of a phrase like this. We tend to reserve periods for full sentences.
hameerabbasi
Feb 11, 2018
Author
Collaborator
Done!
result_shape = _get_broadcast_shape(self.shape, other.shape)
Parameters
----------
args : tuple[COO]
mrocklin
Feb 10, 2018
Collaborator
If you're trying for parametrized Python type annotations, then I think it's supposed to be standard to use capitalized types like `List[COO]` or `Tuple[np.ndarray]`.
mrocklin
Feb 10, 2018
Collaborator
I don't know though, this is somewhat new to me.
hameerabbasi
Feb 10, 2018
Author
Collaborator
`tuple` and `list` tend to work better with intersphinx and code type annotations, so I tend to prefer those. Of course, I could import something, but then that gives me PEP8 failures as I don't use it in code, just in docstrings.
other_data = other_data[i]
# Filter out scalars as they are 'baked' into the function.
func = _posarg_partial(func, pos, posargs)
args = list(filter(lambda arg: not isscalar(arg), args))
mrocklin
Feb 10, 2018
Collaborator
You might consider `toolz.remove` here.
hameerabbasi
Feb 10, 2018
Author
Collaborator
I'd prefer not to introduce a dependency for something as simple as this.
args = list(args)
posargs = []
pos = []
for i in range(len(args)):
mrocklin
Feb 10, 2018
Collaborator
You might consider `for i, arg in enumerate(args)`, which might be a bit more idiomatic for Python readers.
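As an illustration of the suggested loop shape, the sketch below separates scalar arguments from the rest while recording their positions. It is an assumption-laden stand-in, not the PR's code: the function name `split_scalars` is hypothetical, and `numbers.Number` is used here in place of the `isscalar` check that appears in the diff.

```python
from numbers import Number


def split_scalars(args):
    """Return (positions, scalar values, remaining non-scalar args)."""
    pos, posargs, rest = [], [], []
    for i, arg in enumerate(args):
        if isinstance(arg, Number):
            # Remember where the scalar sat so it can be re-inserted later.
            pos.append(i)
            posargs.append(arg)
        else:
            rest.append(arg)
    return pos, posargs, rest


pos, posargs, rest = split_scalars([[1, 2], 3.0, [4, 5], 7])
print(pos, posargs, rest)
```

With `enumerate`, the index and the value arrive together, so the body never needs to write `args[i]` just to read the current element.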
hameerabbasi
Feb 11, 2018
Author
Collaborator
Done!
@@ -1954,6 +1894,20 @@ def tril(x, k=0):
return COO(coords, data, x.shape, x.has_duplicates, x.sorted)

def _nary_match(*arrays):
mrocklin
Feb 10, 2018
Collaborator
I'm not able to quickly figure out what this function does. Can I ask you for a small docstring? If possible I find small example sections in docstrings to be very helpful when learning codebases that others have written.
hameerabbasi
Feb 10, 2018
Author
Collaborator
Whoops, must have missed that one.
mrocklin
Feb 17, 2018
Collaborator
It looks like this function is no longer used. Delete?
ci, di = _unmatch_coo(func, args, mask, **kwargs)

coords_list.extend(ci)
data_list.extend(di)
mrocklin
Feb 10, 2018
Collaborator
This confuses me and seems concerning. I see that this was a main point of your conversation with @shoyer earlier. I probably have some thinking to do on this problem before I'm able to reasonably comment on this.
I think that it would be good to see a more comprehensive test suite that fully explains the complexity of what we're trying to accomplish here. I think that will make things clearer as we discuss different possibilities. We might ask "why are we doing X" and the answer can be "see test_X". I get the sense that you've thought deeply about this problem and know all of the problems that might arise. It would be very valuable to encode that deep thinking and all of those corner cases into a test suite.
I plan to make more comprehensive tests, yes. But the issue is that some of the complexity can't be directly tested. For example, the optimizations are just that: optimizations. We can design the tests so the optimizations are hit, but we can't know that they kicked in without weird monkey-patching of some sort.
03a552f
to
84e4e6c
It seems there's a slight bug when the number of inputs is greater than 2 and broadcasting is involved. Nothing unfixable, but I'll have to think a bit. I'm on it.
I think that there are probably a lot of correctness tests that could be written as well. In #1 you discussed many situations that might arise and that a system like this would need to handle. Ideally we would encode all of those situations as tests to ensure that future developers don't change code in a way that alters correct behavior here.
other_data = other_data[i]
# Filter out scalars as they are 'baked' into the function.
func = PositinalArgumentPartial(func, pos, posargs)
args = list(filter(lambda arg: not isscalar(arg), args))
shoyer
Feb 11, 2018
Member
optional: consider using a list comprehension instead
hameerabbasi
Feb 12, 2018
Author
Collaborator
Done!
matched_coords : np.ndarray
The overall coordinates that match from both arrays.
args : tuple[COO]
The input :obj:`COO` arrays.
shoyer
Feb 11, 2018
Member
Add in `func`, `mask` and `**kwargs` to the docstring?
hameerabbasi
Feb 12, 2018
Author
Collaborator
Done!
coords_list = []
data_list = []
pos, = np.where([not m for m in mask])
shoyer
Feb 11, 2018
Member
Maybe use `np.flatnonzero()`?
hameerabbasi
Feb 12, 2018
Author
Collaborator
This isn't really a numerical operation. I've converted it to a `tuple(generator comprehension)` form and avoided `np.where` altogether. The exact code is
pos = tuple(i for i, m in enumerate(mask) if not m)
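To make the one-liner concrete, here it is applied to a sample `mask`. The mask values are made up for illustration, and it is assumed only that `mask` is a sequence of truthy and falsy flags; the thread does not say what each flag represents.

```python
# Hypothetical mask: one flag per argument.
mask = (True, False, True, False)

# Collect the indices whose flag is falsy, without going through NumPy.
pos = tuple(i for i, m in enumerate(mask) if not m)
print(pos)
```

The result is a plain tuple of Python integers, which suits bookkeeping code that indexes lists rather than arrays.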
I agree with @mrocklin that a more extensive test suite is vital here. This logic is complicated and fixing bugs later will be hard. I haven't seriously tried to follow it yet. I would suggest parametric tests verifying proper broadcasting with 2 or 3 arguments with:
|
d96e178
to
a715e9c
Some small coverage comments
    pos.append(i)
    posargs.append(args[i])
elif isinstance(arg, SparseArray) and not isinstance(arg, COO):
    args[i] = COO(arg)
mrocklin
Feb 17, 2018
Collaborator
This line doesn't get hit by tests. Should we add a small DOK test?
hameerabbasi
Feb 17, 2018
Author
Collaborator
Added!
    posargs.append(args[i])
elif isinstance(arg, SparseArray) and not isinstance(arg, COO):
    args[i] = COO(arg)
elif not isinstance(arg, COO):
    raise ValueError("Performing this operation would produce "
                     "a dense result: %s" % str(func))
mrocklin
Feb 17, 2018
Collaborator
Same here. No test triggers this error-handling code.
hameerabbasi
Feb 17, 2018
Author
Collaborator
Added a small test that hits this.
args = [arg for arg in args if not isscalar(arg)]

if len(args) == 0:
    return func(**kwargs)
mrocklin
Feb 17, 2018
Collaborator
Also here. No test operates on no args
hameerabbasi
Feb 17, 2018
Author
Collaborator
Added another small test for this.
    raise ValueError('Unknown kwargs %s' % kwargs.keys())

if return_midx and (len(args) != 2 or cache is not None):
    raise NotImplementedError('Matching only supported for two args, and no cache.')
mrocklin
Feb 17, 2018
Collaborator
Do we still need this option?
Do we still need this option?
hameerabbasi
Feb 17, 2018
Author
Collaborator
No, we don't. I'm not omniscient, so I went ahead and added this check in case someone tried to trigger caching on `return_midx` (which we don't cache, it's never repeated); or tried to match indices for `len(args) != 2` (I'm not sure if we'll need this in the future, but we might, and it's useful to err rather than have it return incorrect results).
fs = sparse.elemwise(func, *args)
assert isinstance(fs, COO)

assert_eq(fs, func(*dense_args))
mrocklin
Feb 17, 2018
Collaborator
It would be nice to test and verify that we are not creating unnecessary zeroes in the data attribute. We might either test that explicitly here, or we might put it into `assert_eq`. I've gone ahead and pushed a commit to your branch that adds a check into `assert_eq`. Please remove if you prefer not to add this here.
hameerabbasi
Feb 17, 2018
Author
Collaborator
I'd like to verify we don't create additional zeros for all our operations, so that seems like a rather useful addition.
hameerabbasi
Feb 17, 2018
Author
Collaborator
Although I would prefer to use `np.count_nonzero`.
Edit: I reconsidered, this might be more useful for fill values.
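The kind of check being discussed can be sketched as follows. This is not the commit that was pushed to the branch: the helper name `assert_no_stored_fill` and its `fill_value` parameter are hypothetical, and a plain list stands in for the `data` attribute of a sparse array. Comparing against a configurable fill value rather than counting nonzeros reflects the remark about fill values above.

```python
def assert_no_stored_fill(data, fill_value=0):
    """Fail if any explicitly stored value equals the fill value."""
    redundant = [i for i, v in enumerate(data) if v == fill_value]
    assert not redundant, (
        "stored entries equal to the fill value at positions %s" % redundant
    )


# Passes: every stored value differs from the fill value.
assert_no_stored_fill([1.5, -2.0, 3.0])

# Also passes with a non-zero fill value, where a stored 0 is meaningful.
assert_no_stored_fill([0, 2, 3], fill_value=1)
```

For a fill value of zero, this is the same condition as the number of nonzero stored entries equalling the number of stored entries.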
shoyer
Feb 17, 2018
Member
On a related note, maybe it would be useful to add a test that verifies "sparse" broadcasting is actually done in a sparse way?
I think you could do this by mocking the underlying functions (e.g., `np.multiply`) and then verifying that the calls match expectations.
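A rough sketch of that testing idea, using only the standard library: wrap the scalar operation in a `unittest.mock.Mock` spy and check how often it is called. Everything here is a toy assumption, including the dict-based vector representation and the `sparse_multiply` function; a real test would instead patch whichever function the library actually calls.

```python
import operator
from unittest import mock


def sparse_multiply(a, b, mul=operator.mul):
    """Multiply two sparse vectors stored as {index: value} dicts."""
    # Only indices present in both inputs can give a nonzero product.
    return {i: mul(a[i], b[i]) for i in a.keys() & b.keys()}


a = {0: 2.0, 5: 3.0, 9: 4.0}
b = {5: 10.0, 9: 0.5, 12: 7.0}

# The spy forwards every call to operator.mul while recording it.
spy = mock.Mock(wraps=operator.mul)
result = sparse_multiply(a, b, mul=spy)

# Two shared indices, so exactly two multiplications should have run.
print(result, spy.call_count)
```

As noted elsewhere in this thread, a test like this depends on the implementation and not just the public API, so it can break under a refactor that leaves the results unchanged.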
def value_array(n):
    ar = np.empty((n,), dtype=np.float_)
    ar[:] = value
    return ar
mrocklin
Feb 17, 2018
Collaborator
We might want just a few of the values to be pathological instead of all of them.
hameerabbasi
Feb 17, 2018
Author
Collaborator
I'll modify the test to match that.
hameerabbasi
Feb 17, 2018
Author
Collaborator
Modified.
I've incorporated more or less all of your suggestions about coverage, with one exception (see comments!)
@shoyer do you have a chance to look at this? "Nope" is a fine answer.
I'm guessing @shoyer doesn't work weekends. :-) If there's no reply or a "Nope" by the end of Monday, we can decide what to do next.
It's a mixed bag on weekends, but this weekend my wife is away so I have time for open source :). I'll take a look.
    (2,),
    (3, 2),
    (4, 3, 2),
], lambda x, y, z: (x + y) * z),
shoyer
Feb 17, 2018
Member
Consider doing a full cross-product of shapes and functions here.
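One way to build such a cross-product is `itertools.product`, sketched below. The first shape set and both lambdas are taken from the diff context in this thread; the second shape set is made up for illustration, and the variable names are hypothetical.

```python
import itertools

shape_sets = [
    [(2,), (3, 2), (4, 3, 2)],   # from the diff above
    [(3,), (2, 3), (2, 2, 3)],   # illustrative only
]
funcs = [
    lambda x, y, z: (x + y) * z,
    lambda x, y, z: x - y + z,
]

# Every shape set is paired with every function.
cases = list(itertools.product(shape_sets, funcs))
print(len(cases))
```

In pytest, stacking two `@pytest.mark.parametrize` decorators on one test, one for shapes and one for functions, produces the same set of combinations without building the list by hand.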
hameerabbasi
Feb 17, 2018
Author
Collaborator
Done!
    (4, 4),
    (4, 4, 4),
], lambda x, y, z: x - y + z),
])
shoyer
Feb 17, 2018
Member
It would be good to add checks for a few more variations on the broadcasting logic to exercise the matching logic:
- Dimensions of size 1, e.g., `(3, 1)` + `(3, 4)` -> `(3, 4)`.
- Output shapes that don't match one of the inputs, e.g., `(3, 1)` + `(1, 4)` -> `(3, 4)`.
- Outputs that require matching across three inputs, e.g., `(1, 1, 2)` + `(1, 3, 1)` + `(4, 1, 1)` -> `(4, 3, 2)`.
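The three shape cases listed above can be checked with a small pure-Python implementation of the NumPy broadcasting rule. This is a sketch for illustration and is not the library's `_get_broadcast_shape`; the function name `broadcast_shape` is hypothetical.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast of several shapes, or raise ValueError."""
    result = []
    # Align shapes on their trailing dimensions; missing leading dims act as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError("shapes %s are not broadcastable" % (shapes,))
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Dimensions of size 1.
assert broadcast_shape((3, 1), (3, 4)) == (3, 4)
# Output shape that matches neither input.
assert broadcast_shape((3, 1), (1, 4)) == (3, 4)
# Matching across three inputs.
assert broadcast_shape((1, 1, 2), (1, 3, 1), (4, 1, 1)) == (4, 3, 2)
```

The third case is the interesting one for an N-ary implementation, since no pair of inputs alone determines the output shape.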
hameerabbasi
Feb 17, 2018
Author
Collaborator
The first two were already covered in `test_broadcasting`. I renamed that to `test_binary_broadcasting` and moved it closer to these.
The third, I also added.
I can't seem to be able to respond to your "sparse" broadcasting comment, so I'm responding here. I monkey-patched one of our own functions and verified the behavior is correct there. I also verified Edit: However, I will add that like all monkey patching, it's implementation dependent, not (just) API dependent.
Yeah, I was trying to do that as well. I haven't seen that before. Is checking for the right number of non-zeros in the output not sufficient? Do we have code paths that would re-sparsify a dense intermediate result?
We're actually talking about the "optimized" code path for things like
Ah, happy to retract my comment
No, let's leave it up for other people who can't follow the terminology.
Are there any further comments or is this good to merge?
I haven't reviewed the logic in detail, but the implementation looks relatively sane and I am satisfied with the test coverage.
Merged!
This PR adds N-ary broadcasting operations (in preparation for `where`) and simplifies code for the N-ary case.
Discussed in #1