From 339896fbdf780fba0fb426bed1059d04128da0a9 Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Mon, 6 Feb 2023 17:51:39 +0100 Subject: [PATCH] feat: allow awkward type arrays filtering based on rdfentry (#2202) * fix: use generic int type for offsets * feat: allow awkward type arrays filtering based on rdfentry * fix: pylint fixes * fix: revert type * fix: more complex test as suggested by Jim --- .../_connect/rdataframe/from_rdataframe.py | 31 ++++++++++++- ...filter_multiple_columns_from_rdataframe.py | 46 +++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 tests/test_2202_filter_multiple_columns_from_rdataframe.py diff --git a/src/awkward/_connect/rdataframe/from_rdataframe.py b/src/awkward/_connect/rdataframe/from_rdataframe.py index 41756d7a97..de6f126278 100644 --- a/src/awkward/_connect/rdataframe/from_rdataframe.py +++ b/src/awkward/_connect/rdataframe/from_rdataframe.py @@ -130,6 +130,10 @@ def form_dtype(form): column_types = {} result_ptrs = {} contents = {} + awkward_type_cols = {} + + columns = columns + ("rdfentry_",) + maybe_indexed = False # Important note: This loop is separate from the next one # in order not to trigger the additional RDataFrame @@ -138,6 +142,12 @@ def form_dtype(form): column_types[col] = data_frame.GetColumnType(col) result_ptrs[col] = data_frame.Take[column_types[col]](col) + if ROOT.awkward.is_awkward_type[column_types[col]](): + maybe_indexed = True + + if not maybe_indexed: + columns = columns[:-1] + for col in columns: if ROOT.awkward.is_awkward_type[column_types[col]](): # Retrieve Awkward arrays @@ -149,7 +159,7 @@ def form_dtype(form): lookup = result_ptrs[col].begin().lookup() generator = lookup[col].generator layout = generator.tolayout(lookup[col], 0, ()) - contents[col] = layout + awkward_type_cols[col] = layout else: # Convert the C++ vectors to Awkward arrays form_str = ROOT.awkward.type_to_form[column_types[col], offsets_type](0) @@ -228,4 +238,23 @@ def form_dtype(form): form, length, buffers, byteorder=ak._util.native_byteorder ) + if col == "rdfentry_": + contents[col] = ak.index.Index64( + contents[col].layout.to_backend_array( + allow_missing=True, backend=ak._backends.NumpyBackend.instance() + ) + ) + + for key, value in awkward_type_cols.items(): + if len(contents["rdfentry_"]) < len(value): + contents[key] = ak._util.wrap( + ak.contents.IndexedArray(contents["rdfentry_"], value), + highlevel=True, + ) + else: + contents[key] = value + + if maybe_indexed: + del contents["rdfentry_"] + return ak.zip(contents, depth_limit=1) diff --git a/tests/test_2202_filter_multiple_columns_from_rdataframe.py b/tests/test_2202_filter_multiple_columns_from_rdataframe.py new file mode 100644 index 0000000000..2c326fe257 --- /dev/null +++ b/tests/test_2202_filter_multiple_columns_from_rdataframe.py @@ -0,0 +1,46 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE + +import numpy as np # noqa: F401 +import pytest + +import awkward as ak + +ROOT = pytest.importorskip("ROOT") + + +compiler = ROOT.gInterpreter.Declare + + +def test_data_frame_filter(): + array_x = ak.Array( + [ + {"x": [1.1, 1.2, 1.3]}, + {"x": [2.1, 2.2]}, + {"x": [3.1]}, + {"x": [4.1, 4.2, 4.3, 4.4]}, + {"x": [5.1]}, + ] + ) + array_y = ak.Array([1, 2, 3, 4, 5]) + array_z = ak.Array([[1.1], [2.1, 2.3, 2.4], [3.1], [4.1, 4.2, 4.3], [5.1]]) + + df = ak.to_rdataframe({"x": array_x, "y": array_y, "z": array_z}) + + assert str(df.GetColumnType("x")).startswith("awkward::Record_") + assert df.GetColumnType("y") == "int64_t" + assert df.GetColumnType("z") == "ROOT::VecOps::RVec" + + df = df.Filter("y % 2 == 0") + + out = ak.from_rdataframe( + df, + columns=( + "x", + "y", + "z", + ), + ) + assert out["x"].tolist() == [{"x": [2.1, 2.2]}, {"x": [4.1, 4.2, 4.3, 4.4]}] + assert out["y"].tolist() == [2, 4] + assert out["z"].tolist() == [[2.1, 2.3, 2.4], [4.1, 4.2, 4.3]] + assert len(out) == 2