fix: support merging of RegularArray and NumpyArray #2063

Conversation
```python
# Default merging (can we cast one to the other)
else:
    return self.backend.nplike.can_cast(
        self.dtype, other.dtype, casting="same_kind"
    ) or self.backend.nplike.can_cast(
        other.dtype, self.dtype, casting="same_kind"
    )
```
The idea here is that `np.concatenate` uses something like `can_cast` to test mergeability. I think this should predict `np.concatenate`, at least (I'm not certain).
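As a rough illustration of that prediction in plain NumPy (not the awkward internals), the symmetric `can_cast` check mirrors what `np.concatenate` accepts by default:

```python
import numpy as np

# "same_kind" permits safe casts and casts within a kind,
# e.g. int32 -> float64, but not float64 -> int32
assert np.can_cast(np.int32, np.float64, casting="same_kind")
assert not np.can_cast(np.float64, np.int32, casting="same_kind")

# checking both directions matches np.concatenate's default behavior here:
mixed = np.concatenate([np.zeros(2, np.int32), np.ones(2, np.float64)])
assert mixed.dtype == np.float64
```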
It's specifically this part of the code (line 350 to here) that was tuned to give the right merging behavior for datetime and timedelta types. It was originally wrong; I fixed it here:
0734ab2#diff-cdc2b042962b124ea98cc8bacd6aea030a50dc73f6502f9686475554c8913320
("Wrong" was mixing datetimes or timedeltas of different units without conversion.)
As I said in today's meeting, I'd like to continue being stricter than NumPy, not allowing datetimes or timedeltas to be mixed with non-datetimes or non-timedeltas. When the Array API Consortium eventually gets around to defining time-unit behavior, I suspect that they'll come to the same conclusion: promoting non-time data to act as time-data means guessing a unit, which usually isn't safe. If I'm right and they do pick this behavior, we won't have to change it. (Since it's a behavior without an API marker, like a function argument, it would be hard to set up a deprecation cycle.)
So let's revert the use of `nplike.can_cast`, at least for time units or any conversions that are already in violation of Array API decisions. We can diverge from current-NumPy if we think doing so means making smaller or no changes later, when NumPy is compliant with Array API.
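For reference, this shows how permissive `can_cast(..., casting="same_kind")` is around time units in plain NumPy; this permissiveness is exactly what is under discussion here:

```python
import numpy as np

# same_kind allows casting between datetime64 units in both directions,
# including the precision-losing one:
assert np.can_cast(np.dtype("datetime64[s]"), np.dtype("datetime64[ms]"),
                   casting="same_kind")
assert np.can_cast(np.dtype("datetime64[ms]"), np.dtype("datetime64[s]"),
                   casting="same_kind")

# and np.concatenate silently promotes mixed units to the finer one:
a = np.array(["2020-01-01"], dtype="datetime64[s]")
b = np.array(["2020-01-01"], dtype="datetime64[ms]")
assert np.concatenate([a, b]).dtype == np.dtype("datetime64[ms]")
```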
I'm surprised, though, that the above didn't break any tests. I remember there being some tests that were sensitive to the time-unit fix, but I can't find them now.
src/awkward/contents/numpyarray.py
```python
def _to_regular_primitive(self):
    index = tuple([slice(None, 1)] * len(self.shape))
    new_data = self.backend.nplike.broadcast_to(self._data[index], self.shape)
    return NumpyArray(
        new_data, backend=self.backend, parameters=self.parameters
    ).to_RegularArray()
```
Yes, this is a terrible name; this PR is a PoC and I'd appreciate a better suggestion. Maybe `_to_mergeable_hack` ;)
```python
@property
def issubdtype(self):
    return numpy.issubdtype
```
`issubdtype` behaves like a method descriptor.
Is this just formal, for code symmetry? Whether it's an attribute or a property on the `NumpyMetadata.instance()` singleton, `issubdtype` acts like a function (it doesn't consume a `self` argument).
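A minimal sketch of the point being made (the class name below is a hypothetical stand-in, not the real `NumpyMetadata`):

```python
import numpy

class MetadataSketch:
    # hypothetical stand-in for the singleton in question
    @property
    def issubdtype(self):
        return numpy.issubdtype

meta = MetadataSketch()
# attribute access runs the property getter; the returned callable is the
# plain numpy function and does not receive `meta` as an argument
assert meta.issubdtype(numpy.float64, numpy.floating)
```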
N.B. the tests will fail, because this changes merging behavior for date-times (in line with NumPy AFAICT).
```diff
@@ -989,7 +989,7 @@ def unique(self, *args, **kwargs):
         try_touch_data(x)
         raise ak._errors.wrap_error(NotImplementedError)

-    def concatenate(self, arrays):
+    def concatenate(self, arrays, casting="same_kind"):
```
This is not right - we should actually consider this variable in `concatenate`. However, this would fall under the typetracer / nplike refactoring that I'm also working on separately, so I'll handle it there.
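If and when the parameter is honored, the nplike implementation would presumably just forward it; a sketch under that assumption (not the actual refactor), relying on the fact that `np.concatenate` has accepted a `casting` keyword since NumPy 1.20:

```python
import numpy as np

def concatenate(arrays, casting="same_kind"):
    # hypothetical forwarding wrapper: pass the casting policy through
    # to np.concatenate instead of ignoring it
    return np.concatenate(arrays, casting=casting)

out = concatenate([np.zeros(2, np.int32), np.ones(2, np.float64)])
assert out.dtype == np.float64
```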
Okay.
TODO: restore existing datetime mergeability test (stricter)
Just a few comments, because this is in the draft stage.
```diff
-        if isinstance(
+        elif isinstance(
```
The original idea of having the option-type/index check be a separate `if` statement from the subsequent chain is that it's conceptually distinct. Most node types, across the `Content` class hierarchy, start by asking this same question before getting into questions that are specific to each node type.

But since the option-type/index `if` block ends by returning, `if` and `elif` are functionally equivalent here, and looking at any one node's `_mergeable_next` implementation in isolation, a single `if` chain would be simpler.
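A toy illustration of why the two spellings behave identically once the first block returns (all names here are made up for the example):

```python
def with_if(x):
    if x is None:          # the conceptually-distinct "option-type" question
        return "option"
    if isinstance(x, int): # node-type-specific questions follow
        return "int"
    return "other"

def with_elif(x):
    if x is None:
        return "option"
    elif isinstance(x, int):  # elif: same behavior, single visual chain
        return "int"
    return "other"

assert all(with_if(v) == with_elif(v) for v in (None, 3, "s"))
```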
```python
if len(self.shape) > 1:
    return self._to_regular_primitive()._mergeable(other, mergebool)
```
This is the main part of the PR.

In our conversation, I was advocating for `assert len(self.shape) == 1` here, with an upstream conversion of the whole layout to replace `NumpyArray` with `RegularArray` before it enters into `mergeable` and `mergemany`.

I'm assuming that the `NumpyArray` would have to be converted (often a copy) into `RegularArray` again if `mergeable` returns `True` and the `mergemany` step happens. I was just hoping that could be made a single conversion. (It's not a big deal if it isn't.)

Part of that reasoning is that `mergeable`/`mergemany` is an internal API that we have to be careful with, anyway. (Calling `mergemany` on data that would have `mergeable == False` is undefined; I suspect a segfault would result.)
This conversion is low cost, as only for JAX do we actually need to copy any data. Is your point about wanting an `assert` here still necessary in spite of that?
We concluded from some experiments (and hindsight) that when `reshape` reduces the number of dimensions, it does have to copy, for any backend.
```python
>>> a = np.arange(3*5*7).reshape(3, 5, 7)
>>> b = a[1:, 1:, 1:]
>>> c = a.T
>>> a[-1, -1, -1] = 123
>>> b
array([[[ 43,  44,  45,  46,  47,  48],
        [ 50,  51,  52,  53,  54,  55],
        [ 57,  58,  59,  60,  61,  62],
        [ 64,  65,  66,  67,  68,  69]],

       [[ 78,  79,  80,  81,  82,  83],
        [ 85,  86,  87,  88,  89,  90],
        [ 92,  93,  94,  95,  96,  97],
        [ 99, 100, 101, 102, 103, 123]]])
>>> c  # also has a "123" in the last element
>>> b_flat = b.reshape((-1,) + b.shape[2:])
>>> c_flat = c.reshape((-1,) + c.shape[2:])
>>> a[-1, -1, -1] = 321
>>> b_flat
array([[ 43,  44,  45,  46,  47,  48],
       [ 50,  51,  52,  53,  54,  55],
       [ 57,  58,  59,  60,  61,  62],
       [ 64,  65,  66,  67,  68,  69],
       [ 78,  79,  80,  81,  82,  83],
       [ 85,  86,  87,  88,  89,  90],
       [ 92,  93,  94,  95,  96,  97],
       [ 99, 100, 101, 102, 103, 123]])
>>> c_flat  # also does not have a "321" in the last element
```
It's because when a dimension is removed, there's nowhere for NumPy to record a different stride (byte-skip) between the affected numbers.
So this `NumpyArray` → `RegularArray` conversion is a genuine copy in general.

JAX had some different performance cost, which NumPy and CuPy didn't have, which I can't remember right now.
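The forced copy can also be seen directly with `np.shares_memory`, using the same setup as the session above:

```python
import numpy as np

a = np.arange(3 * 5 * 7).reshape(3, 5, 7)
b = a[1:, 1:, 1:]                        # a strided view: no copy yet
assert np.shares_memory(a, b)

b_flat = b.reshape((-1,) + b.shape[2:])  # collapsing dimensions forces a copy
assert not np.shares_memory(a, b_flat)
```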
Right, if we use `to_RegularArray()` we would see this penalty (for non-contiguous arrays).

Hence, I wrote this function, which provides a hacky solution given that we don't care about the buffer contents: slice the array down to a single scalar, and then broadcast it to the appropriate dimensions. This creates an array with `strides = (0, 0, ..., 0)`, which consumes no additional memory. Broadcasting like this doesn't work for JAX, because JAX doesn't have views over buffers in the same way that NumPy and CuPy do.
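The trick used by `_to_regular_primitive` can be reproduced in plain NumPy:

```python
import numpy as np

data = np.arange(3 * 5 * 7).reshape(3, 5, 7)
index = tuple([slice(None, 1)] * data.ndim)        # keep one element per axis
scalarized = data[index]                           # shape (1, 1, 1), still a view
regular = np.broadcast_to(scalarized, data.shape)  # back to (3, 5, 7)

assert regular.shape == data.shape
assert regular.strides == (0, 0, 0)                # zero strides: no new buffer
assert np.shares_memory(regular, data)
```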
Okay, yeah, that's cheap-as-free! Now I remember that this, specifically, is the case that JAX doesn't handle as well as NumPy and CuPy. And it's fine to ignore this performance issue in JAX (it's a corner case of a corner case).
> I'm surprised, though, that the above didn't break any tests.
It does, it breaks many tests! That was why I raised it today, because we clearly care about this.
The Python attribute resolution machinery will always invoke the property's `__get__`.
@jpivarski I've restored the date-time handling that corresponds to the previous behaviour. Due to the changes in this PR, Previously, we just returned
Okay, I'm on board with this now. My one performance question was addressed, the dates are as strict as they used to be, and calling

If it's actually less permissive, then in principle that would be backward incompatible (something that used to work now won't). Do you know of cases that we used to merge but would now be considered not mergeable?

Yes, that was why I expounded upon this to assert that although it cannot be more permissive, I don't think it's less permissive either. It doesn't fail any tests, and I can't think of any types that would not be supported. Ooh, actually, let me try one combination.

Okay, so if the one new test you're trying passes, then I think this PR is ready to be merged.

Now that I think about it,
This is a PoC to fix #2058. Jim and I also discussed doing this properly at the form level, which I might even explore in another draft PR.