Skip to content

Commit

Permalink
feat: allow awkward type arrays filtering based on rdfentry (#2202)
Browse files Browse the repository at this point in the history
* fix: use generic int type for offsets

* feat: allow awkward type arrays filtering based on rdfentry

* fix: pylint fixes

* fix: revert type

* fix: more complex test as suggested by Jim
  • Loading branch information
ianna committed Feb 6, 2023
1 parent b92d3f1 commit 339896f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
31 changes: 30 additions & 1 deletion src/awkward/_connect/rdataframe/from_rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def form_dtype(form):
column_types = {}
result_ptrs = {}
contents = {}
awkward_type_cols = {}

columns = columns + ("rdfentry_",)
maybe_indexed = False

# Important note: This loop is separate from the next one
# in order not to trigger the additional RDataFrame
Expand All @@ -138,6 +142,12 @@ def form_dtype(form):
column_types[col] = data_frame.GetColumnType(col)
result_ptrs[col] = data_frame.Take[column_types[col]](col)

if ROOT.awkward.is_awkward_type[column_types[col]]():
maybe_indexed = True

if not maybe_indexed:
columns = columns[:-1]

for col in columns:
if ROOT.awkward.is_awkward_type[column_types[col]](): # Retrieve Awkward arrays

Expand All @@ -149,7 +159,7 @@ def form_dtype(form):
lookup = result_ptrs[col].begin().lookup()
generator = lookup[col].generator
layout = generator.tolayout(lookup[col], 0, ())
contents[col] = layout
awkward_type_cols[col] = layout

else: # Convert the C++ vectors to Awkward arrays
form_str = ROOT.awkward.type_to_form[column_types[col], offsets_type](0)
Expand Down Expand Up @@ -228,4 +238,23 @@ def form_dtype(form):
form, length, buffers, byteorder=ak._util.native_byteorder
)

if col == "rdfentry_":
contents[col] = ak.index.Index64(
contents[col].layout.to_backend_array(
allow_missing=True, backend=ak._backends.NumpyBackend.instance()
)
)

for key, value in awkward_type_cols.items():
if len(contents["rdfentry_"]) < len(value):
contents[key] = ak._util.wrap(
ak.contents.IndexedArray(contents["rdfentry_"], value),
highlevel=True,
)
else:
contents[key] = value

if maybe_indexed:
del contents["rdfentry_"]

return ak.zip(contents, depth_limit=1)
46 changes: 46 additions & 0 deletions tests/test_2202_filter_multiple_columns_from_rdataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numpy as np # noqa: F401
import pytest

import awkward as ak

ROOT = pytest.importorskip("ROOT")


compiler = ROOT.gInterpreter.Declare


def test_data_frame_filter():
array_x = ak.Array(
[
{"x": [1.1, 1.2, 1.3]},
{"x": [2.1, 2.2]},
{"x": [3.1]},
{"x": [4.1, 4.2, 4.3, 4.4]},
{"x": [5.1]},
]
)
array_y = ak.Array([1, 2, 3, 4, 5])
array_z = ak.Array([[1.1], [2.1, 2.3, 2.4], [3.1], [4.1, 4.2, 4.3], [5.1]])

df = ak.to_rdataframe({"x": array_x, "y": array_y, "z": array_z})

assert str(df.GetColumnType("x")).startswith("awkward::Record_")
assert df.GetColumnType("y") == "int64_t"
assert df.GetColumnType("z") == "ROOT::VecOps::RVec<double>"

df = df.Filter("y % 2 == 0")

out = ak.from_rdataframe(
df,
columns=(
"x",
"y",
"z",
),
)
assert out["x"].tolist() == [{"x": [2.1, 2.2]}, {"x": [4.1, 4.2, 4.3, 4.4]}]
assert out["y"].tolist() == [2, 4]
assert out["z"].tolist() == [[2.1, 2.3, 2.4], [4.1, 4.2, 4.3]]
assert len(out) == 2

0 comments on commit 339896f

Please sign in to comment.