Skip to content

Commit

Permalink
Fixed a few bugs in mergeable and simplify_uniontype.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 9, 2021
1 parent 5b0a273 commit 0734ab2
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/awkward/_v2/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
self._content.mergeable(other.content, mergebool)
return self._content.mergeable(other.content, mergebool)

else:
return self._content.mergeable(other, mergebool)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_v2/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
self._content.mergeable(other.content, mergebool)
return self._content.mergeable(other.content, mergebool)

else:
return self._content.mergeable(other, mergebool)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
self._content.mergeable(other.content, mergebool)
return self._content.mergeable(other.content, mergebool)

else:
return self._content.mergeable(other, mergebool)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_v2/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
self._content.mergeable(other.content, mergebool)
return self._content.mergeable(other.content, mergebool)

else:
return self._content.mergeable(other, mergebool)
Expand Down
5 changes: 3 additions & 2 deletions src/awkward/_v2/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def mergeable(self, other, mergebool):
),
):
return True

if isinstance(
other,
(
Expand All @@ -729,7 +730,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
self.mergeable(other.content, mergebool)
return self.mergeable(other.content, mergebool)

if isinstance(
other,
Expand All @@ -739,7 +740,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.listoffsetarray.ListOffsetArray,
),
):
self._content.mergeable(other.content, mergebool)
return self._content.mergeable(other.content, mergebool)

else:
return False
Expand Down
32 changes: 22 additions & 10 deletions src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,29 +357,41 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
return self.mergeable(other.content, mergebool)

if self._data.ndim == 0:
return False
return self.mergeable(other._content, mergebool)

if isinstance(other, ak._v2.contents.numpyarray.NumpyArray):
if self._data.ndim != other.data.ndim:
if self._data.ndim != other._data.ndim:
return False

if (
not mergebool
and self._data.dtype != other._data.dtype
and (
self._data.dtype.type is np.bool_
or other._data.dtype.type is np.bool_
)
):
return False

if self.dtype != other.dtype and (
self.dtype == np.datetime64 or other.dtype == np.datetime64
if self._data.dtype != other._data.dtype and (
self._data.dtype == np.datetime64 or other._data.dtype == np.datetime64
):
return False

if self.dtype != other.dtype and (
self.dtype == np.timedelta64 or other.dtype == np.timedelta64
if self._data.dtype != other._data.dtype and (
self._data.dtype == np.timedelta64
or other._data.dtype == np.timedelta64
):
return False

if len(self.shape) > 1 and len(self.shape) != (other.shape):
if (
len(self._data.shape) > 1
and self._data.shape[1:] != other._data.shape[1:]
):
return False

return True

else:
return False

Expand Down
35 changes: 21 additions & 14 deletions src/awkward/_v2/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,25 +453,32 @@ def mergeable(self, other, mergebool=True):
):
return self.mergeable(other.content, mergebool)

if isinstance(other, ak._v2.contents.recordarray.RecordArray):
if isinstance(other, RecordArray):
if self.is_tuple and other.is_tuple:
if len(self.contents) == len(other.contents):
for i in range(len(self.contents)):
if not self.contents[i].mergeable(other.contents[i], mergebool):
if len(self._contents) == len(other._contents):
for i in range(len(self._contents)):
if not self._contents[i].mergeable(
other._contents[i], mergebool
):
return False
return True
else:
return True

elif not self.is_tuple and not other.is_tuple:
self_fields = self.fields.copy()
other_fields = other.fields.copy()
self_fields.sort()
other_fields.sort()
if self_fields == other_fields:
for field in self_fields:
if not self[field].mergeable(other[field], mergebool):
return False
if set(self._fields) != set(other._fields):
return False

for i, field in enumerate(self._fields):
x = self._contents[i]
y = other._contents[other.field_to_index(field)]
if not x.mergeable(y, mergebool):
return False
else:
return True
return False

else:
return False

else:
return False

Expand Down
1 change: 1 addition & 0 deletions src/awkward/_v2/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def _offsets_and_flattened(self, axis, depth):
def mergeable(self, other, mergebool):
if not _parameters_equal(self._parameters, other._parameters):
return False

if isinstance(
other,
(
Expand Down
70 changes: 38 additions & 32 deletions src/awkward/_v2/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,21 @@ def _getitem_next(self, head, tail, advanced):
raise AssertionError(repr(head))

def simplify_uniontype(self, merge=True, mergebool=False):
tags = ak._v2.index.Index8.empty(len(self), self.nplike)
index = ak._v2.index.Index64.empty(len(self), self.nplike)
if len(self._index) < len(self._tags):
raise ValueError("invalid UnionArray: len(index) < len(tags)")

nplike = self._tags.nplike
length = len(self._tags)
tags = ak._v2.index.Index8.empty(length, nplike)
index = ak._v2.index.Index64.empty(length, nplike)
contents = []
for i in range(len(self.contents)):
rawcontent = self.contents[i]
if isinstance(rawcontent, ak._v2.contents.unionarray.UnionArray):

innertags = rawcontent.tags
innerindex = rawcontent.index
innercontents = rawcontent.contents
for i in range(len(self._contents)):
if isinstance(self._contents[i], UnionArray):
innertags = self._contents[i]._tags
innerindex = self._contents[i]._index
innercontents = self._contents[i]._contents

for j in range(len(innercontents)):
unmerged = True
for k in range(len(contents)):
Expand All @@ -386,16 +391,16 @@ def simplify_uniontype(self, merge=True, mergebool=False):
innertags.dtype.type,
innerindex.dtype.type,
](
tags.to(self.nplike),
index.to(self.nplike),
self._tags.to(self.nplike),
self._index.to(self.nplike),
innertags.to(self.nplike),
innerindex.to(self.nplike),
tags.to(nplike),
index.to(nplike),
self._tags.to(nplike),
self._index.to(nplike),
innertags.to(nplike),
innerindex.to(nplike),
k,
j,
i,
len(self),
length,
len(contents[k]),
)
)
Expand All @@ -414,24 +419,25 @@ def simplify_uniontype(self, merge=True, mergebool=False):
innertags.dtype.type,
innerindex.dtype.type,
](
tags.to(self.nplike),
index.to(self.nplike),
self._tags.to(self.nplike),
self._index.to(self.nplike),
innertags.to(self.nplike),
innerindex.to(self.nplike),
tags.to(nplike),
index.to(nplike),
self._tags.to(nplike),
self._index.to(nplike),
innertags.to(nplike),
innerindex.to(nplike),
len(contents),
j,
i,
len(self),
length,
0,
)
)
contents.append(innercontents[j])

else:
unmerged = True
for k in range(len(contents)):
if contents[k] == self.contents[i]:
if contents[k] is self._contents[i]:
self._handle_error(
self.nplike[
"awkward_UnionArray_simplify_one",
Expand All @@ -446,14 +452,14 @@ def simplify_uniontype(self, merge=True, mergebool=False):
self._index.to(self.nplike),
k,
i,
len(self),
length,
0,
)
)
unmerged = False
break

elif merge and contents[k].mergeable(self.contents[i], mergebool):
elif merge and contents[k].mergeable(self._contents[i], mergebool):
self._handle_error(
self.nplike[
"awkward_UnionArray_simplify_one",
Expand All @@ -468,11 +474,11 @@ def simplify_uniontype(self, merge=True, mergebool=False):
self._index.to(self.nplike),
k,
i,
len(self),
length,
len(contents[k]),
)
)
contents[k] = contents[k].merge(self.contents[i])
contents[k] = contents[k].merge(self._contents[i])
unmerged = False
break

Expand All @@ -491,15 +497,16 @@ def simplify_uniontype(self, merge=True, mergebool=False):
self._index.to(self.nplike),
len(contents),
i,
len(self),
length,
0,
)
)

contents.append(self.contents[i])
contents.append(self._contents[i])

if len(contents) > 2 ** 7:
raise AssertionError("FIXME: handle UnionArray with more than 127 contents")
raise NotImplementedError(
"FIXME: handle UnionArray with more than 127 contents"
)

if len(contents) == 1:
return contents[0]._carry(index, True, NestedIndexError)
Expand Down Expand Up @@ -602,7 +609,6 @@ def _offsets_and_flattened(self, axis, depth):
def mergeable(self, other, mergebool):
if not _parameters_equal(self._parameters, other._parameters):
return False

return True

def merging_strategy(self, others):
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_v2/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def mergeable(self, other, mergebool):
ak._v2.contents.unmaskedarray.UnmaskedArray,
),
):
self._content.mergeable(other.content, mergebool)
return self._content.mergeable(other.content, mergebool)

else:
return self._content.mergeable(other, mergebool)
Expand Down

0 comments on commit 0734ab2

Please sign in to comment.