Skip to content

Commit

Permalink
fix: preserve dimensions for keepdims=True, axis=None reductions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Jan 31, 2023
1 parent e9173b7 commit a91d96b
Show file tree
Hide file tree
Showing 18 changed files with 130 additions and 89 deletions.
16 changes: 10 additions & 6 deletions src/awkward/_do.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,31 +217,34 @@ def pad_none(
return layout._pad_none(length, axis, 1, clip)


def completely_flatten(
def remove_structure(
layout: Content | Record,
backend: Backend | None = None,
flatten_records: bool = True,
function_name: str | None = None,
drop_nones: bool = True,
keepdims: bool = False,
):
if isinstance(layout, Record):
return completely_flatten(
return remove_structure(
layout._array[layout._at : layout._at + 1],
backend,
flatten_records,
function_name,
drop_nones,
keepdims,
)

else:
if backend is None:
backend = layout._backend
arrays = layout._completely_flatten(
arrays = layout._remove_structure(
backend,
{
"flatten_records": flatten_records,
"function_name": function_name,
"drop_nones": drop_nones,
"keepdims": keepdims,
},
)
return tuple(arrays)
Expand Down Expand Up @@ -314,15 +317,16 @@ def reduce(
behavior: dict | None = None,
):
if axis is None:
parts = completely_flatten(layout, flatten_records=False, drop_nones=False)
parts = remove_structure(
layout, flatten_records=False, drop_nones=False, keepdims=keepdims
)

if len(parts) > 1:
# Records fail outright because `flatten_records=False`, so the only other type
# that can return multiple parts here is the union array
raise ak._errors.wrap_error(
ValueError(
"cannot use axis=None with keepdims=True on an array containing "
"irreducible unions"
"cannot use axis=None on an array containing irreducible unions"
)
)
elif len(parts) == 0:
Expand Down
9 changes: 8 additions & 1 deletion src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ def arrays_approx_equal(
atol: float = 1e-8,
dtype_exact: bool = True,
check_parameters=True,
check_regular=True,
) -> bool:
# TODO: this should not be needed after refactoring nplike mechanism
import awkward.forms.form
Expand Down Expand Up @@ -798,7 +799,13 @@ def visitor(left, right) -> bool:
right = right.to_IndexedOptionArray64()

if type(left) is not type(right):
return False
if not check_regular and (
left.is_list and right.is_regular or left.is_regular and right.is_list
):
left = left.to_ListOffsetArray64()
right = right.to_ListOffsetArray64()
else:
return False

if left.length != right.length:
return False
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,10 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.to_ByteMaskedArray()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,10 +972,10 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.to_IndexedOptionArray64()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ def drop_none(self):
def _drop_none(self) -> Content:
raise ak._errors.wrap_error(NotImplementedError)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
raise ak._errors.wrap_error(NotImplementedError)

def _recursively_apply(
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return backend.nplike.empty(0, dtype=np.float64)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
return []

def _recursively_apply(
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,8 +958,8 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.project()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
return self.project()._completely_flatten(backend, options)
def _remove_structure(self, backend, options):
return self.project()._remove_structure(backend, options)

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def _reduce_next(
"reduce_next with unbranching depth > negaxis is only "
"expected to return RegularArray or ListOffsetArray or "
"IndexedOptionArray; "
"instead, it returned " + out
"instead, it returned {}".format(type(out).__name__)
)
)

Expand Down Expand Up @@ -1526,10 +1526,10 @@ def _to_backend_array(self, allow_missing, backend):
else:
return content

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
12 changes: 2 additions & 10 deletions src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,16 +1379,8 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.to_RegularArray()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
if (
self.parameter("__array__") == "string"
or self.parameter("__array__") == "bytestring"
):
return [self]
else:
next = self.to_ListOffsetArray64(False)
flat = next.content[next.offsets[0] : next.offsets[-1]]
return flat._completely_flatten(backend, options)
def _remove_structure(self, backend, options):
return self.to_ListOffsetArray64(False)._remove_structure(backend, options)

def _drop_none(self):
return self.to_ListOffsetArray64()._drop_none()
Expand Down
21 changes: 18 additions & 3 deletions src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,15 +1962,30 @@ def _to_backend_array(self, allow_missing, backend):

return self.to_RegularArray()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if (
self.parameter("__array__") == "string"
or self.parameter("__array__") == "bytestring"
):
return [self]
else:
flat = self._content[self._offsets[0] : self._offsets[-1]]
return flat._completely_flatten(backend, options)
content = self._content[self._offsets[0] : self._offsets[-1]]
contents = content._remove_structure(backend, options)
if options["keepdims"]:
return [
ListOffsetArray(
ak.index.Index64(
backend.index_nplike.asarray(
[0, backend.index_nplike.shape_item_as_scalar(c.length)]
)
),
c,
parameters=self._parameters,
)
for c in contents
]
else:
return contents

def _drop_none(self):
if self._content.is_option:
Expand Down
8 changes: 6 additions & 2 deletions src/awkward/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,10 +1204,14 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return to_nplike(self.data, backend.nplike, from_nplike=self._backend.nplike)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if options["keepdims"]:
shape = (1,) * (self._data.ndim - 1) + (-1,)
else:
shape = (-1,)
return [
ak.contents.NumpyArray(
backend.nplike.reshape(self._raw(backend.nplike), (-1,)),
backend.nplike.reshape(self._raw(backend.nplike), shape),
backend=backend,
)
]
Expand Down
6 changes: 2 additions & 4 deletions src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,13 +922,11 @@ def _to_backend_array(self, allow_missing, backend):

return out

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if options["flatten_records"]:
out = []
for content in self._contents:
out.extend(
content[: self._length]._completely_flatten(backend, options)
)
out.extend(content[: self._length]._remove_structure(backend, options))
return out
else:
in_function = ""
Expand Down
13 changes: 10 additions & 3 deletions src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
),
)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if (
self.parameter("__array__") == "string"
or self.parameter("__array__") == "bytestring"
Expand All @@ -1227,8 +1227,15 @@ def _completely_flatten(self, backend, options):
else:
index_nplike = self._backend.index_nplike
length = index_nplike.mul_shape_item(self._length, self._size)
flat = self._content[: index_nplike.shape_item_as_scalar(length)]
return flat._completely_flatten(backend, options)
content = self._content[: index_nplike.shape_item_as_scalar(length)]
contents = content._remove_structure(backend, options)
if options["keepdims"]:
return [
RegularArray(c, size=c.length, parameters=self._parameters)
for c in contents
]
else:
return contents

def _drop_none(self):
return self.to_ListOffsetArray64()._drop_none()
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,14 +1469,14 @@ def _to_backend_array(self, allow_missing, backend):

return out

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
out = []
for i in range(len(self._contents)):
index = self._index[self._tags.data == i]
out.extend(
self._contents[i]
._carry(index, False)
._completely_flatten(backend, options)
._remove_structure(backend, options)
)
return out

Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,10 @@ def _to_backend_array(self, allow_missing, backend):
else:
return content

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _impl(array, axis, highlevel, behavior):
layout = ak.operations.to_layout(array, allow_record=False, allow_other=False)

if axis is None:
out = ak._do.completely_flatten(layout, function_name="ak.flatten")
out = ak._do.remove_structure(layout, function_name="ak.flatten")
assert isinstance(out, tuple) and all(
isinstance(x, ak.contents.NumpyArray) for x in out
)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_ravel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def ravel(array, *, highlevel=True, behavior=None):
def _impl(array, highlevel, behavior):
layout = ak.operations.to_layout(array, allow_record=False, allow_other=False)

out = ak._do.completely_flatten(layout, function_name="ak.ravel", drop_nones=False)
out = ak._do.remove_structure(layout, function_name="ak.ravel", drop_nones=False)
assert isinstance(out, tuple) and all(
isinstance(x, ak.contents.Content) for x in out
)
Expand Down
Loading

0 comments on commit a91d96b

Please sign in to comment.