From e483aa0f7ea2d3787ee5496a4ef10332d3215fd6 Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Tue, 9 Apr 2024 15:56:42 +0200 Subject: [PATCH] [PyROOT][RDF] Support conversion of `bool` columns to NumPy arrays The `bool` columns in RDF are special, because the Take action returns a `std::vector<bool>`, which has an implementation-dependent memory layout for space optimization. Therefore, I suggest supporting taking `bool` columns as `unsigned char` with `Take()`, such that in `RDataFrameAsNumpy` the values can be directly taken as bytes. This avoids superfluous copying in the code, and keeps the special logic on the pythonization side minimal. Closes #8639. --- .../python/ROOT/_pythonization/_rdataframe.py | 4 ++++ .../python/ROOT/_pythonization/_rvec.py | 7 ++++--- .../pyroot/pythonizations/test/rdataframe_asnumpy.py | 12 ++++++++++++ tree/dataframe/src/RDFUtils.cxx | 9 ++++++++- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdataframe.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdataframe.py index 6277c12ba1e60..7dc227e2cab50 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdataframe.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdataframe.py @@ -229,6 +229,10 @@ def RDataFrameAsNumpy(df, columns=None, exclude=None, lazy=False): result_ptrs = {} for column in columns: column_type = df.GetColumnType(column) + # bool columns should be taken as unsigned chars, because NumPy stores + # bools in bytes - different from the std::vector<bool> returned by the + # action, which might do some space optimization + column_type = "unsigned char" if column_type == "bool" else column_type result_ptrs[column] = df.Take[column_type](column) result = AsNumpyResult(result_ptrs, columns) diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rvec.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rvec.py index 
d8de0aced19a6..92f1d082752a3 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rvec.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rvec.py @@ -71,14 +71,15 @@ _array_interface_dtype_map = { - "float": "f", + "Long64_t": "i", + "ULong64_t": "u", "double": "f", + "float": "f", "int": "i", "long": "i", - "Long64_t": "i", + "unsigned char": "b", "unsigned int": "u", "unsigned long": "u", - "ULong64_t": "u", } diff --git a/bindings/pyroot/pythonizations/test/rdataframe_asnumpy.py b/bindings/pyroot/pythonizations/test/rdataframe_asnumpy.py index 8c0670b2753c2..3c98233dea944 100644 --- a/bindings/pyroot/pythonizations/test/rdataframe_asnumpy.py +++ b/bindings/pyroot/pythonizations/test/rdataframe_asnumpy.py @@ -333,6 +333,18 @@ def test_cloning(self): self.assertSequenceEqual( asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist()) + def test_bool_column(self): + """ + Testing converting bool columns to NumPy arrays. + """ + name = "bool_branch" + n_events = 100 + cut = 50 + df = ROOT.RDataFrame(n_events).Define(name, f"(int)rdfentry_ > {cut}") + arr = df.AsNumpy([name])[name] + ref = np.arange(0, n_events) > cut + self.assertTrue(all(arr == ref)) # test values + self.assertEqual(arr.dtype, ref.dtype) # test type if __name__ == '__main__': unittest.main() diff --git a/tree/dataframe/src/RDFUtils.cxx b/tree/dataframe/src/RDFUtils.cxx index d92e7f6e94fa1..57f73a58216ef 100644 --- a/tree/dataframe/src/RDFUtils.cxx +++ b/tree/dataframe/src/RDFUtils.cxx @@ -384,6 +384,13 @@ unsigned int GetColumnWidth(const std::vector<std::string>& names, const unsigne void CheckReaderTypeMatches(const std::type_info &colType, const std::type_info &requestedType, const std::string &colName) { + bool explicitlySupported = false; + // We want to explicitly support the reading of bools as unsigned char, as + // this is quite common to circumvent the std::vector<bool> specialization. 
+ if (TypeID2TypeName(colType) == "bool" && TypeID2TypeName(requestedType) == "unsigned char") { + explicitlySupported = true; + } + // Here we compare names and not typeinfos since they may come from two different contexts: a compiled // and a jitted one. const auto diffTypes = (0 != std::strcmp(colType.name(), requestedType.name())); @@ -392,7 +399,7 @@ void CheckReaderTypeMatches(const std::type_info &colType, const std::type_info return colTClass && colTClass->InheritsFrom(TClass::GetClass(requestedType)); }; - if (diffTypes && !inheritedType()) { + if (!explicitlySupported && diffTypes && !inheritedType()) { const auto tName = TypeID2TypeName(requestedType); const auto colTypeName = TypeID2TypeName(colType); std::string errMsg = "RDataFrame: type mismatch: column \"" + colName + "\" is being used as ";