diff --git a/tree/dataframe/inc/ROOT/RDF/SnapshotHelpers.hxx b/tree/dataframe/inc/ROOT/RDF/SnapshotHelpers.hxx index 29edd09c52210..781f6af0d141c 100644 --- a/tree/dataframe/inc/ROOT/RDF/SnapshotHelpers.hxx +++ b/tree/dataframe/inc/ROOT/RDF/SnapshotHelpers.hxx @@ -123,7 +123,9 @@ struct RBranchData { TBranch *fOutputBranch = nullptr; void *fBranchAddressForCArrays = nullptr; // Used to detect if branch addresses need to be updated - int fVariationIndex = -1; // For branches that are only valid if a specific filter passed + // A negative index indicates no variations, 0 is for nominal, >0 marks columns that are only valid if a specific + // filter passed + int fVariationIndex = -1; std::variant fTypeData = FundamentalType{0}; bool fIsCArray = false; bool fIsDefine = false; diff --git a/tree/dataframe/src/RDFSnapshotHelpers.cxx b/tree/dataframe/src/RDFSnapshotHelpers.cxx index 1dcf7aae30aa4..5477c9289df55 100644 --- a/tree/dataframe/src/RDFSnapshotHelpers.cxx +++ b/tree/dataframe/src/RDFSnapshotHelpers.cxx @@ -1173,7 +1173,10 @@ void ROOT::Internal::RDF::SnapshotHelperWithVariations::RegisterVariedColumn(uns std::string const &variationName) { if (columnIndex == originalColumnIndex) { - fBranchData[columnIndex].fVariationIndex = variationIndex; // The base column has variations + // This is a nominal column, but it participates in variations. + // It always needs to be written, but we still need to create a mask bit to mark when nominal is invalid. + assert(variationIndex == 0); + fBranchData[columnIndex].fVariationIndex = 0; fOutputHandle->RegisterBranch(fBranchData[columnIndex].fOutputBranchName, variationIndex); } else if (columnIndex >= fBranchData.size()) { // First task, need to create branches @@ -1218,15 +1221,20 @@ void ROOT::Internal::RDF::SnapshotHelperWithVariations::Exec(unsigned int /*slot for (std::size_t i = 0; i < values.size(); i++) { const auto variationIndex = fBranchData[i].fVariationIndex; if (variationIndex < 0) { - // Branch without variations + // Branch without variations, it always needs to be written SetBranchesHelper(fInputTree, *fOutputHandle->fTree, fBranchData, i, fOptions.fBasketSize, values[i]); - } else if (filterPassed[variationIndex]) { - // Branch with variations - const bool fundamentalType = fBranchData[i].WriteValueIfFundamental(values[i]); - if (!fundamentalType) { - SetBranchesHelper(fInputTree, *fOutputHandle->fTree, fBranchData, i, fOptions.fBasketSize, values[i]); + } else { + // Nominal will always be written, systematics only if needed + if (variationIndex == 0 || filterPassed[variationIndex]) { + const bool fundamentalType = fBranchData[i].WriteValueIfFundamental(values[i]); + if (!fundamentalType) { + SetBranchesHelper(fInputTree, *fOutputHandle->fTree, fBranchData, i, fOptions.fBasketSize, values[i]); + } + } + + if (filterPassed[variationIndex]) { + fOutputHandle->SetMaskBit(variationIndex); } - fOutputHandle->SetMaskBit(variationIndex); } } diff --git a/tree/dataframe/test/dataframe_snapshotWithVariations.cxx b/tree/dataframe/test/dataframe_snapshotWithVariations.cxx index f7aa7e3088e20..005bbf9fccd16 100644 --- a/tree/dataframe/test/dataframe_snapshotWithVariations.cxx +++ b/tree/dataframe/test/dataframe_snapshotWithVariations.cxx @@ -12,7 +12,8 @@ #include #include -#include +#include +#include #include constexpr bool verbose = false; @@ -53,7 +54,8 @@ void checkOutput(TTree &tree, std::vector const &systematics, F &&a ASSERT_GT(tree.GetEntry(i), 0); EXPECT_EQ(x, -1 * y); - if (!activeCuts(x, y)) { + if (!activeCuts(x, y) && !sysName.empty()) { + // Branches with systematics should be zeroed when cuts don't pass EXPECT_EQ(x, X_t{}); EXPECT_EQ(y, Y_t{}); } @@ -98,7 +100,7 @@ TEST(RDFVarySnapshot, SimpleRDFWithFilters) for (const auto branchName : {"x", "y", "x__xVar_0", "x__xVar_1", "y__xVar_0", "y__xVar_0"}) EXPECT_NE(tree->FindBranch(branchName), nullptr) << branchName; - checkOutput(*tree, std::vector{"__xVar_0", "__xVar_1"}, cuts); + checkOutput(*tree, std::vector{"", "__xVar_0", "__xVar_1"}, cuts); if (HasFailure()) { tree->Print(); @@ -264,50 +266,60 @@ TEST(RDFVarySnapshot, Bitmask) if (HasFailure()) break; } + } - // Test that the Masked column reader works - { - SCOPED_TRACE("Usage of bitmask in RDF"); - auto rdf = ROOT::RDataFrame(treename, filename); - - auto filterAvailable = rdf.FilterAvailable("x"); - auto meanX = filterAvailable.Mean("x"); - auto meanY = filterAvailable.Mean("y"); - auto count = filterAvailable.Count(); - - EXPECT_EQ(count.GetValue(), 3ull); // 0, 6, 12 - EXPECT_EQ(meanX.GetValue(), 6.); - EXPECT_EQ(meanY.GetValue(), -6.); - - // Test reading invalid columns - auto mean = rdf.Mean("x"); - EXPECT_THROW(mean.GetValue(), std::out_of_range); - - for (unsigned int systematicIndex : {0, 1, 100}) { - const std::string systematic = "__xVar_" + std::to_string(systematicIndex); - auto filterAv = rdf.FilterAvailable("x" + systematic); - auto meanX_sys = filterAv.Mean("x" + systematic); - auto meanY_sys = filterAv.Mean("y" + systematic); - auto count_sys = filterAv.Count(); - - std::vector expect(N); - std::iota(expect.begin(), expect.end(), systematicIndex); - - const auto nVal = std::count_if(expect.begin(), expect.end(), [](int v) { return v % 6 == 0; }); - // gcc8.5 on alma8 doesn't support transform_reduce, nor reduce - // const int sum = std::transform_reduce(expect.begin(), expect.end(), 0, std::plus<>(), - // [](int v) { return v % 6 == 0 ? v : 0; }); - std::transform(expect.begin(), expect.end(), expect.begin(), [](int v) { return v % 6 == 0 ? v : 0; }); - const int sum = std::accumulate(expect.begin(), expect.end(), 0); - - ASSERT_EQ(count_sys.GetValue(), nVal) << "systematic: " << systematic; - EXPECT_EQ(meanX_sys.GetValue(), sum / nVal) << "systematic: " << systematic; - EXPECT_EQ(meanY_sys.GetValue(), -1. * sum / nVal) << "systematic: " << systematic; - } + // Test that the Masked column reader works + { + SCOPED_TRACE("Usage of bitmask in RDF"); + auto rdf = ROOT::RDataFrame(treename, filename); + + auto filterAvailable = rdf.FilterAvailable("x"); + auto meanX = filterAvailable.Mean("x"); + auto meanY = filterAvailable.Mean("y"); + auto count = filterAvailable.Count(); + + EXPECT_EQ(count.GetValue(), 3ull); // 0, 6, 12 + EXPECT_EQ(meanX.GetValue(), 6.); + EXPECT_EQ(meanY.GetValue(), -6.); + + // Test reading invalid columns + auto mean = rdf.Mean("x"); + EXPECT_THROW(mean.GetValue(), std::out_of_range); + + for (unsigned int systematicIndex : {0, 1, 100}) { + const std::string systematic = "__xVar_" + std::to_string(systematicIndex); + auto filterAv = rdf.FilterAvailable("x" + systematic); + auto meanX_sys = filterAv.Mean("x" + systematic); + auto meanY_sys = filterAv.Mean("y" + systematic); + auto count_sys = filterAv.Count(); + + std::vector expect(N); + std::iota(expect.begin(), expect.end(), systematicIndex); + + const auto nVal = std::count_if(expect.begin(), expect.end(), [](int v) { return v % 6 == 0; }); + // gcc8.5 on alma8 doesn't support transform_reduce, nor reduce + // const int sum = std::transform_reduce(expect.begin(), expect.end(), 0, std::plus<>(), + // [](int v) { return v % 6 == 0 ? v : 0; }); + std::transform(expect.begin(), expect.end(), expect.begin(), [](int v) { return v % 6 == 0 ? v : 0; }); + const int sum = std::accumulate(expect.begin(), expect.end(), 0); + + ASSERT_EQ(count_sys.GetValue(), nVal) << "systematic: " << systematic; + EXPECT_EQ(meanX_sys.GetValue(), sum / nVal) << "systematic: " << systematic; + EXPECT_EQ(meanY_sys.GetValue(), -1. * sum / nVal) << "systematic: " << systematic; } + } - if (HasFailure()) { - tree->Scan("entry:x:y:x__xVar_0:y__xVar_0:x__xVar_1:y__xVar_1:x__xVar_2:y__xVar_2"); + if (HasFailure()) { + auto file = std::make_unique(filename, "READ"); + std::unique_ptr tree{file->Get(treename.data())}; + ASSERT_NE(tree, nullptr); + tree->Scan("entry:R_rdf_mask_testTree_0:x:y:x__xVar_0:y__xVar_0:x__xVar_1:y__xVar_1:x__xVar_2:y__xVar_2"); + + std::unique_ptr>> map{ + file->Get>>( + ("R_rdf_column_to_bitmask_mapping_" + treename).c_str())}; + for (auto const &[name, mapping] : *map) { + std::cout << std::setw(20) << name << " --> " << mapping.first << " " << mapping.second << "\n"; } } } @@ -422,15 +434,16 @@ TEST(RDFVarySnapshot, SnapshotCollections) : ((systematicName.find("xVariation_0") != std::string::npos) ? entry + 1 : entry * 3); const bool passCuts = (originalX % 2 == 0) || originalX == 5; - if (passCuts) + if (passCuts || systematicName.empty()) { EXPECT_EQ(x, originalX) << "sys:'" << systematicName << "' originalX: " << originalX << " event: " << event; - else + ASSERT_EQ(y->size(), 4) << "sys:'" << systematicName << "' entry: " << entry << " originalX: " << originalX + << " event: " << event; + for (unsigned int i = 0; i < y->size(); ++i) { + EXPECT_EQ((*y)[i], x + i); + } + } else { EXPECT_EQ(x, 0) << "sys:'" << systematicName << "' originalX: " << originalX << " event: " << event; - - ASSERT_EQ(y->size(), passCuts ? 4 : 0) - << "sys:'" << systematicName << "' entry: " << entry << " originalX: " << originalX << " event: " << event; - for (unsigned int i = 0; i < y->size(); ++i) { - EXPECT_EQ((*y)[i], x + i); + ASSERT_EQ(y->size(), 0); } } tree->ResetBranchAddresses(); @@ -555,7 +568,8 @@ TEST(RDFVarySnapshot, IncludeDependentColumns_JIT) TFile file(fileName); auto tree = file.Get("Events"); ASSERT_NE(tree, nullptr); - tree->Scan(); + if (verbose) + tree->Scan(); double Muon_pt, Muon_pt_up, Muon_pt_down; double Muon_2pt, Muon_2pt_up, Muon_2pt_down; @@ -577,3 +591,76 @@ TEST(RDFVarySnapshot, IncludeDependentColumns_JIT) EXPECT_EQ(2. * Muon_2pt, Muon_2pt_up); } } + +TEST(RDFVarySnapshot, TwoVaryExpressions) +{ + constexpr auto filename = "VarySnapshot_TwoVaryExpressions.root"; + RemoveFileRAII(filename); + constexpr unsigned int N = 10; + ROOT::RDF::RSnapshotOptions options; + options.fOverwriteIfExists = true; + options.fIncludeVariations = true; + + auto cuts = [](float x, float y) { return (x < 20 || x > 30) && (y < 600 || y > 700); }; + + auto snap = ROOT::RDataFrame(N) + .Define("x", [](ULong64_t e) -> float { return 10.f * e; }, {"rdfentry_"}) + .Define("y", [](ULong64_t e) -> float { return 100.f * e; }, {"rdfentry_"}) + .Vary( + "x", [](float x) { return ROOT::RVecF{x - 0.5f, x + 0.5f}; }, {"x"}, {"down", "up"}, "xVar") + .Vary( + "y", [](float y) { return ROOT::RVecF{y - 0.5f, y + 0.5f}; }, {"y"}, {"down", "up"}, "yVar") + .Filter(cuts, {"x", "y"}) + .Snapshot("t", filename, {"x", "y"}, options); + + { + std::unique_ptr file{TFile::Open(filename)}; + auto tree = file->Get("t"); + + EXPECT_EQ(tree->GetEntries(), 10); + EXPECT_EQ(tree->GetNbranches(), 7); // 6 branches for x/y with variations, bitmask + for (const auto branchName : {"x", "y", "x__xVar_down", "x__xVar_up", "y__yVar_down", "y__yVar_up"}) + EXPECT_NE(tree->GetBranch(branchName), nullptr) << branchName; + + for (std::string combos : {"x:y", "x__xVar_down:y", "x__xVar_up:y", "x:y__yVar_down", "x:y__yVar_up"}) { + const auto xName = combos.substr(0, combos.find(':')); + const auto yName = combos.substr(combos.find(':') + 1); + + float x; + float y; + ASSERT_EQ(TTree::kMatch, tree->SetBranchAddress(xName.c_str(), &x)) << xName; + ASSERT_EQ(TTree::kMatch, tree->SetBranchAddress(yName.c_str(), &y)) << yName; + + for (unsigned int i = 0; i < tree->GetEntries(); ++i) { + ASSERT_GT(tree->GetEntry(i), 0); + const float expectedX = i * 10.f - (xName.find("xVar_down") != std::string::npos) * 0.5f + + (xName.find("xVar_up") != std::string::npos) * 0.5f; + const float expectedY = i * 100.f - (yName.find("yVar_down") != std::string::npos) * 0.5f + + (yName.find("yVar_up") != std::string::npos) * 0.5f; + + if (cuts(expectedX, expectedY)) { + EXPECT_EQ(x, expectedX) << "entry:" << i << " (" << xName << "=" << expectedX << ", " << yName << "=" + << expectedY << ")"; + EXPECT_EQ(y, expectedY) << "entry:" << i << " (" << xName << "=" << expectedX << ", " << yName << "=" + << expectedY << ")"; + EXPECT_TRUE(cuts(x, y)) << "entry:" << i << " (" << xName << "=" << expectedX << ", " << yName << "=" + << expectedY << ")"; + } + } + + // Unbind the branches from the stack-local variables + tree->ResetBranchAddresses(); + if (HasFailure()) + break; + } + + if (verbose || HasFailure()) { + auto map = file->Get>>( + "R_rdf_column_to_bitmask_mapping_t"); + for (auto const &[name, mapping] : *map) { + std::cout << std::setw(20) << name << " --> " << mapping.first << " " << mapping.second << "\n"; + } + printTree(*tree); + } + } +}