Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 45 additions & 24 deletions tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "ROOT/RDF/ColumnReaderUtils.hxx"
#include "ROOT/RDF/GraphNode.hxx"
#include "ROOT/RDF/RActionBase.hxx"
#include "ROOT/RDF/RFilterBase.hxx"
#include "ROOT/RDF/RJittedFilter.hxx"
#include "ROOT/RDF/RLoopManager.hxx"

#include <cstddef> // std::size_t
Expand All @@ -37,8 +39,12 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {
// Template needed to avoid dependency on ActionHelpers.hxx
Helper fHelper;

/// Pointer to the previous node in this branch of the computation graph
std::vector<std::shared_ptr<PrevNode>> fPrevNodes;
// If the PrevNode is an RJittedFilter, our collection of previous nodes will have to use the RFilterBase type:
// we'll have an RJittedFilter for the nominal case, but the others will be concrete filters.
using PrevNodeCommon_t = std::conditional_t<std::is_same_v<PrevNode, ROOT::Detail::RDF::RJittedFilter>,
ROOT::Detail::RDF::RFilterBase, PrevNode>;
/// Previous nodes in the computation graph. First element is nominal, others are varied.
std::vector<std::shared_ptr<PrevNodeCommon_t>> fPrevNodes;

/// Column readers per slot and per input column
std::vector<std::vector<RColumnReaderBase *>> fValues;
Expand All @@ -51,13 +57,46 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {

/// Return the sample callback provided by the wrapped helper.
ROOT::RDF::SampleCallback_t GetSampleCallback() final
{
   return fHelper.GetSampleCallback();
}

/// Append to fPrevNodes one upstream node per variation of this action, after the nominal one.
///
/// On entry fPrevNodes must contain only the nominal previous node. On exit it contains
/// 1 + GetVariations().size() entries: the nominal node followed, for each variation in order, by either
/// the varied filter obtained from the nominal node (if the nominal node lists that variation) or the
/// nominal node itself.
void AppendVariedPrevNodes()
{
   // This method only makes sense if we're appending the varied filters to the list after the nominal
   assert(fPrevNodes.size() == 1);
   const auto &currentVariations = GetVariations();

   // Hold the nominal node through an owning copy rather than an iterator into fPrevNodes:
   // the vector grows below, and any reallocation would leave an iterator dangling.
   const auto nominalPrevNode = fPrevNodes.front();

   // If this node hangs from the RLoopManager itself, just use that as the upstream node for each variation
   if (static_cast<ROOT::Detail::RDF::RNodeBase *>(nominalPrevNode.get()) == fLoopManager) {
      fPrevNodes.resize(1 + currentVariations.size(), nominalPrevNode);
      return;
   }

   // Otherwise, append one entry per variation of this action
   const auto &prevVariations = nominalPrevNode->GetVariations();

   // The loop below appends exactly one element per *current* variation (not per upstream variation),
   // so that is the count the storage must be sized for.
   fPrevNodes.reserve(1 + currentVariations.size());

   // Need to populate parts of the computation graph for which we have empty shells, e.g. RJittedFilters
   if (!currentVariations.empty())
      fLoopManager->Jit();

   for (const auto &variation : currentVariations) {
      if (IsStrInVec(variation, prevVariations)) {
         // The upstream node knows this variation: hang from its varied counterpart
         fPrevNodes.emplace_back(
            std::static_pointer_cast<PrevNodeCommon_t>(nominalPrevNode->GetVariedFilter(variation)));
      } else {
         // The upstream node is unaffected by this variation: reuse the nominal node
         fPrevNodes.push_back(nominalPrevNode);
      }
   }
}

public:
RActionSnapshot(Helper &&h, const std::vector<std::string> &columns,
const std::vector<const std::type_info *> &colTypeIDs, std::shared_ptr<PrevNode> pd,
const RColumnRegister &colRegister)
: RActionBase(pd->GetLoopManagerUnchecked(), columns, colRegister, pd->GetVariations()),
fHelper(std::move(h)),
fPrevNodes{std::move(pd)},
fPrevNodes{std::static_pointer_cast<PrevNodeCommon_t>(pd)},
fValues(GetNSlots()),
fColTypeIDs(colTypeIDs)
{
Expand All @@ -69,26 +108,7 @@ public:
fIsDefine.push_back(colRegister.IsDefineOrAlias(columns[i]));

if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
if (const auto &variations = GetVariations(); !variations.empty()) {
// Get pointers to previous nodes of all systematics
fPrevNodes.reserve(1 + variations.size());
auto nominalFilter = fPrevNodes.front();
if (static_cast<RNodeBase *>(nominalFilter.get()) == fLoopManager) {
// just fill this with the RLoopManager N times
fPrevNodes.resize(1 + variations.size(), nominalFilter);
} else {
// create varied versions of the previous filter node
const auto &prevVariations = nominalFilter->GetVariations();
for (const auto &variation : variations) {
if (IsStrInVec(variation, prevVariations)) {
fPrevNodes.emplace_back(
std::static_pointer_cast<PrevNode>(nominalFilter->GetVariedFilter(variation)));
} else {
fPrevNodes.emplace_back(nominalFilter);
}
}
}
}
AppendVariedPrevNodes();
}
}

Expand Down Expand Up @@ -251,7 +271,8 @@ public:
/// Build a new RActionSnapshot from fHelper.CallMakeNew(newResult), reusing this action's column names,
/// column type IDs and column register, with the first entry of fPrevNodes as its previous node.
/// NOTE(review): newResult is handed to the helper as-is; what it must point to is not visible here.
std::unique_ptr<RActionBase> CloneAction(void *newResult) final
{
return std::make_unique<RActionSnapshot>(fHelper.CallMakeNew(newResult), GetColumnNames(), fColTypeIDs,
fPrevNodes.front(), GetColRegister());
std::static_pointer_cast<PrevNode>(fPrevNodes.front()),
GetColRegister());
}
};

Expand Down
35 changes: 35 additions & 0 deletions tree/dataframe/test/dataframe_snapshotWithVariations.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,38 @@ TEST(RDFVarySnapshot, SnapshotVirtualClass)
}
}
}

// https://github.com/root-project/root/issues/20320
TEST(RDFVarySnapshot, GH20320)
{
   // Output file name carries the number of the issue this test guards against
   const char *fileName{"dataframe_snapshot_with_variations_regression_gh20320.root"};
   RemoveFileRAII(fileName);

   // Single-entry dataset with one defined column
   ROOT::RDataFrame df{1};

   auto df_def = df.Define("val", []() { return 2; });

   // Register a "var" variation of "val" with tags "down" (val - 1) and "up" (val + 1)
   auto df_var =
      df_def.Vary("val", [](int val) { return ROOT::RVecI{val - 1, val + 1}; }, {"val"}, {"down", "up"}, "var");

   // Jitted filters used to break the Snapshot because:
   // - It did not JIT the RJittedFilter before requesting the varied filters
   // - It did not take into account that the previous nodes of the Snapshot could be of different types
   auto df_fil = df_var.Filter("val > 0");

   ROOT::RDF::RSnapshotOptions opts;
   opts.fIncludeVariations = true;
   auto snap = df_fil.Snapshot("tree", fileName, {"val"}, opts);

   // Read back the nominal column and both varied columns from the snapshot result
   auto take_val = snap->Take<int>("val");
   auto take_var_up = snap->Take<int>("val__var_up");
   auto take_var_down = snap->Take<int>("val__var_down");

   // Unsigned literals: size() is unsigned, avoid signed/unsigned comparison in EXPECT_EQ
   EXPECT_EQ(take_val->size(), 1u);
   EXPECT_EQ(take_var_up->size(), 1u);
   EXPECT_EQ(take_var_down->size(), 1u);

   EXPECT_EQ(take_var_down->front(), 1);
   EXPECT_EQ(take_val->front(), 2);
   EXPECT_EQ(take_var_up->front(), 3);
}
Loading