Skip to content

Commit

Permalink
fix: support conversion to arrow and back with non-option Unknown type (#3085)
Browse files Browse the repository at this point in the history

This adds logic to add metadata to ExtensionArray objects in Arrow arrays on conversion from Awkward,
and sets nullability in the arrow objects with more specificity.

* Metadata added is "is_nonnullable_nulltype"
* Content base class has a class method: `_arrow_needs_option_type()`. Normally this returns
  the value of is_option, but EmptyArray overrides it to always return True, so that the
  Arrow field holding it is marked nullable.
  • Loading branch information
tcawlfield committed Apr 29, 2024
1 parent 5e05752 commit 21af5dc
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 19 deletions.
38 changes: 29 additions & 9 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ def __init__(
node_parameters,
record_is_tuple,
record_is_scalar,
is_nonnullable_nulltype=False,
):
self._mask_type = mask_type
self._node_type = node_type
self._mask_parameters = mask_parameters
self._node_parameters = node_parameters
self._record_is_tuple = record_is_tuple
self._record_is_scalar = record_is_scalar
self._is_nonnullable_nulltype = is_nonnullable_nulltype
super().__init__(storage_type, "awkward")

def __str__(self):
Expand Down Expand Up @@ -140,6 +142,7 @@ def __arrow_ext_serialize__(self):
"node_parameters": self._node_parameters,
"record_is_tuple": self._record_is_tuple,
"record_is_scalar": self._record_is_scalar,
"is_nonnullable_nulltype": self._is_nonnullable_nulltype,
}
).encode(errors="surrogatescape")

Expand All @@ -154,6 +157,7 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized):
metadata["node_parameters"],
metadata["record_is_tuple"],
metadata["record_is_scalar"],
is_nonnullable_nulltype=metadata.get("is_nonnullable_nulltype", False),
)

@property
Expand Down Expand Up @@ -610,10 +614,18 @@ def popbuffers(paarray, awkwardarrow_type, storage_type, buffers, generate_bitma
validbits = buffers.pop(0)
assert storage_type.num_fields == 0

empty_array = ak.contents.EmptyArray(
parameters=node_parameters(awkwardarrow_type)
)
if awkwardarrow_type is not None and awkwardarrow_type._is_nonnullable_nulltype:
# Special case: pyarrow does not support a non-option null type,
# So we short-cut the Option wrapper when _is_nonnullable_nulltype is True.
return revertable(empty_array, empty_array)

# This is already an option-type and offsets-corrected, so no popbuffers_finalize.
return ak.contents.IndexedOptionArray(
ak.index.Index64(numpy.full(len(paarray), -1, dtype=np.int64)),
ak.contents.EmptyArray(parameters=node_parameters(awkwardarrow_type)),
empty_array,
parameters=mask_parameters(awkwardarrow_type),
)

Expand Down Expand Up @@ -857,17 +869,25 @@ def form_popbuffers(awkwardarrow_type, storage_type):


def to_awkwardarrow_type(
storage_type, use_extensionarray, record_is_scalar, mask, node
storage_type,
use_extensionarray,
record_is_scalar,
mask,
node,
is_nonnullable_nulltype=False,
):
if use_extensionarray:
return AwkwardArrowType(
storage_type,
direct_Content_subclass_name(mask),
direct_Content_subclass_name(node),
None if mask is None else mask.parameters,
None if node is None else node.parameters,
node.is_tuple if isinstance(node, ak.contents.RecordArray) else None,
record_is_scalar,
storage_type=storage_type,
mask_type=direct_Content_subclass_name(mask),
node_type=direct_Content_subclass_name(node),
mask_parameters=None if mask is None else mask.parameters,
node_parameters=None if node is None else node.parameters,
record_is_tuple=node.is_tuple
if isinstance(node, ak.contents.RecordArray)
else None,
record_is_scalar=record_is_scalar,
is_nonnullable_nulltype=is_nonnullable_nulltype,
)
else:
return storage_type
Expand Down
4 changes: 4 additions & 0 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,10 @@ def _fill_none(self, value: Content) -> Content:
def copy(self, *, parameters: JSONMapping | None = UNSET) -> Self:
raise NotImplementedError

@classmethod
def _arrow_needs_option_type(cls):
    """Report whether an Arrow field holding this content should be nullable.

    The base implementation simply mirrors ``cls.is_option``. The result is
    intended to be passed to ``pyarrow.Field.with_nullable`` when building
    Arrow types for nested content.
    """
    return cls.is_option  # is_option is a class property of Meta


@register_backend_lookup_factory
def find_content_backend(obj: type):
Expand Down
15 changes: 10 additions & 5 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,12 @@ def _to_arrow(
if options["emptyarray_to"] is None:
return pyarrow.Array.from_buffers(
ak._connect.pyarrow.to_awkwardarrow_type(
pyarrow.null(),
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
storage_type=pyarrow.null(),
use_extensionarray=options["extensionarray"],
record_is_scalar=options["record_is_scalar"],
mask=mask_node,
node=self,
is_nonnullable_nulltype=mask_node is None,
),
length,
[
Expand All @@ -386,6 +387,10 @@ def _to_arrow(
)
return next._to_arrow(pyarrow, mask_node, validbytes, length, options)

@classmethod
def _arrow_needs_option_type(cls):
    """Always report that an Arrow field holding this content must be nullable.

    NOTE(review): presumably needed because this content is emitted with
    Arrow's null storage type, which Arrow/Parquet only accept in nullable
    fields -- confirm against the ``_to_arrow`` implementation.
    """
    return True  # This overrides Content._arrow_needs_option_type

def _to_backend_array(self, allow_missing, backend):
return backend.nplike.empty(0, dtype=np.float64)

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1974,7 +1974,7 @@ def _to_arrow(
)

content_type = pyarrow.list_(paarray.type).value_field.with_nullable(
akcontent.is_option
akcontent._arrow_needs_option_type()
)

if issubclass(npoffsets.dtype.type, np.int32):
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ def _to_arrow(
types = pyarrow.struct(
[
pyarrow.field(self.index_to_field(i), values[i].type).with_nullable(
x.is_option
x._arrow_needs_option_type()
)
for i, x in enumerate(self._contents)
]
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ def _to_arrow(
)

content_type = pyarrow.list_(paarray.type).value_field.with_nullable(
akcontent.is_option
akcontent._arrow_needs_option_type()
)

return pyarrow.Array.from_buffers(
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,8 @@ def _to_arrow(
types = pyarrow.union(
[
pyarrow.field(str(i), values[i].type).with_nullable(
mask_node is not None or self._contents[i].is_option
mask_node is not None
or self._contents[i]._arrow_needs_option_type()
)
for i in range(len(values))
],
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_to_arrow_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _impl(
# accounting for options above the record layout.
arrow_fields.append(
pyarrow.field(name, arrow_arrays[-1].type).with_nullable(
outer_field_content.is_option
outer_field_content._arrow_needs_option_type()
)
)

Expand Down
164 changes: 164 additions & 0 deletions tests/test_2340_unknown_type_to_arrow_and_back.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

# from __future__ import annotations
from __future__ import annotations

import io

import numpy as np
import pytest

import awkward as ak
from awkward.contents import (
EmptyArray,
NumpyArray,
RecordArray,
RegularArray,
UnionArray,
)
from awkward.operations import to_list
from awkward.types import ListType, OptionType, UnknownType

pyarrow = pytest.importorskip("pyarrow")
pq = pytest.importorskip("pyarrow.parquet")


def test_bare_unknown():
    """Round-trip a list array whose items have unknown type through Arrow."""
    source = ak.Array([[], []])  # every sub-list has item type "unknown"
    assert source[0].type.content == UnknownType()

    as_arrow = ak.to_arrow(source)
    storage = as_arrow.type.storage_type
    # The list storage type has a single (item) field carrying the pyarrow
    # storage info; older pyarrow versions lack the `field` accessor, in
    # which case the field-level checks are skipped.
    if hasattr(storage, "field"):
        item_field = storage.field(0)
        assert item_field.type.storage_type == pyarrow.null()
        # A null-typed field has to be nullable to be valid in Arrow.
        assert item_field.nullable
    array_is_valid_within_parquet(as_arrow)

    restored = ak.from_arrow(as_arrow)
    assert to_list(restored) == [[], []]
    assert restored.type == source.type


def test_option_unknown():
    """Round-trip lists of option-type unknown items through Arrow."""
    option_unknown_list = ListType(OptionType(UnknownType()))

    source = ak.Array([[None, None], []])  # item type is ?unknown
    assert source.type.content == option_unknown_list

    as_arrow = ak.to_arrow(source)
    storage = as_arrow.type.storage_type
    if hasattr(storage, "field"):
        item_field = storage.field(0)
        assert item_field.type.storage_type == pyarrow.null()
        # Nullable here as well, this time because the items are ?unknown.
        assert item_field.nullable
    array_is_valid_within_parquet(as_arrow)

    restored = ak.from_arrow(as_arrow)
    assert restored.type == source.type
    assert to_list(restored) == [[None, None], []]

    # Second scenario: the only None lives in the fourth list, which the
    # slice drops, yet the sliced array still has ?unknown item type.
    sliced = ak.Array([[], [], [], [None]])[0:3]
    assert sliced.type.content == option_unknown_list

    sliced_arrow = ak.to_arrow(sliced)
    sliced_storage = sliced_arrow.type.storage_type
    if hasattr(sliced_storage, "field"):
        sliced_field = sliced_storage.field(0)
        assert sliced_field.type.storage_type == pyarrow.null()
        assert sliced_field.nullable  # still nullable, as above

    sliced_restored = ak.from_arrow(sliced_arrow)
    assert sliced_restored.type.content == option_unknown_list
    assert len(sliced_restored) == 3


def test_toplevel_unknown():
    """Round-trip top-level unknown and option-unknown arrays through Arrow."""
    empty = ak.Array([])
    assert empty.type.content == UnknownType()

    empty_arrow = ak.to_arrow(empty)
    assert len(empty_arrow) == 0
    # Nullability cannot be inspected for this Arrow array because its type
    # has no fields at all; the assertion below just demonstrates that.
    assert empty_arrow.type.num_fields == 0
    # Writing it to Parquet is still a worthwhile check.
    array_is_valid_within_parquet(empty_arrow)

    restored = ak.from_arrow(empty_arrow)
    assert restored.type == empty.type
    assert to_list(restored) == []

    # A zero-length slice of [None] is a top-level Option<EmptyArray>.
    empty_option = ak.Array([None])[0:0]
    assert empty_option.type.content == OptionType(UnknownType())
    option_arrow = ak.to_arrow(empty_option)
    option_restored = ak.from_arrow(option_arrow)
    assert option_restored.type.content == OptionType(UnknownType())


def test_recordarray_with_unknowns():
    """Convert a zero-length record with an EmptyArray field to Arrow."""
    # A rather artificial layout that is unlikely to occur in real data, so
    # only nullability of the "x" field and Parquet validity are checked.
    record = RecordArray([EmptyArray(), NumpyArray([])], ["x", "y"], length=0)
    record_arrow = ak.to_arrow(record)
    storage = record_arrow.type.storage_type
    if hasattr(storage, "field"):
        assert storage.field(0).nullable
    array_is_valid_within_parquet(record_arrow)


def test_table_with_unknowns():
    """Build an Arrow table from a record with an EmptyArray column and write it."""
    # Another unusual layout: an EmptyArray next to a two-element NumpyArray.
    record = RecordArray([EmptyArray(), NumpyArray([1, 2])], ["x", "y"])
    arrow_table = ak.to_arrow_table(record)
    assert arrow_table.field(0).nullable
    sink = io.BytesIO()
    pq.write_table(arrow_table, sink)


def test_regulararray_with_unknown():
    """Convert a RegularArray wrapping an EmptyArray to Arrow."""
    # A size-0 RegularArray over an EmptyArray: a degenerate but legal layout.
    regular = RegularArray(EmptyArray(), 0)
    regular_arrow = ak.to_arrow(regular)
    storage = regular_arrow.type.storage_type
    if hasattr(storage, "field"):
        assert storage.field(0).nullable
    assert to_list(regular_arrow) == []
    array_is_valid_within_parquet(regular_arrow)


def test_unionarray_with_unknown():
    """Convert a UnionArray that has an EmptyArray member to Arrow and back."""
    # A union with an EmptyArray member has no practical use, but it lets us
    # exercise the union code path.
    # Every tag is 1, so no element is (or could be) drawn from the EmptyArray.
    tags = ak.index.Index8(np.array([1, 1, 1], dtype=np.int8))
    index = ak.index.Index64(np.array([0, 1, 2], dtype=np.int64))
    union = UnionArray(
        tags=tags,
        index=index,
        contents=[EmptyArray(), NumpyArray([10, 20, 30])],
    )
    assert to_list(union) == [10, 20, 30]

    union_arrow = ak.to_arrow(union)
    storage = union_arrow.type.storage_type
    if hasattr(storage, "field"):
        assert storage.field(0).nullable
        assert not storage.field(1).nullable
    # array_is_valid_within_parquet is not called here: it fails for
    # reasons unrelated to this test.
    round_tripped = ak.from_arrow(union_arrow)
    # Values survive the round trip, though the UnionArray layout does not
    # (possibly a separate issue).
    assert to_list(round_tripped) == [10, 20, 30]


#### Helper method(s)


def array_is_valid_within_parquet(arrow_array):
    """
    Check that the given Arrow array can be written out as Parquet.

    The array is wrapped in a single-column table ("col1"), validated, and
    written to an in-memory buffer. Prior to 2340, this would raise:
        pyarrow.lib.ArrowInvalid: NullType Arrow field must be nullable
    """
    single_column = pyarrow.Table.from_arrays([arrow_array], names=["col1"])
    # Full validation on its own does *not* raise the error described above...
    single_column.validate(full=True)
    sink = io.BytesIO()
    # ...whereas actually writing the table performs the check we need.
    pq.write_table(single_column, sink)
    assert len(sink.getbuffer()) > 0

0 comments on commit 21af5dc

Please sign in to comment.