Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions roottest/root/dataframe/testIMT.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ void getTracks(unsigned int mu, FourVectors& tracks) {
// This makes the example stand-alone
void FillTree(const char* filename, const char* treeName) {
if (!gSystem->AccessPathName(filename)) return;
TFile f(filename,"RECREATE");
TTree t(treeName,treeName);
auto f = std::make_unique<TFile>(filename, "RECREATE");
auto t = std::make_unique<TTree>(treeName, treeName);
double b1;
int b2;
std::vector<FourVector> tracks;
std::vector<double> dv {-1,2,3,4};
std::list<int> sl {1,2,3,4};
t.Branch("b1", &b1);
t.Branch("b2", &b2);
t.Branch("tracks", &tracks);
t.Branch("dv", &dv);
t.Branch("sl", &sl);
t->Branch("b1", &b1);
t->Branch("b2", &b2);
t->Branch("tracks", &tracks);
t->Branch("dv", &dv);
t->Branch("sl", &sl);

int nevts = 16000;
for(int i = 0; i < nevts; ++i) {
Expand All @@ -77,11 +77,9 @@ void FillTree(const char* filename, const char* treeName) {

dv.emplace_back(i);
sl.emplace_back(i);
t.Fill();
t->Fill();
}
t.Write();
f.Close();
return;
f->Write();
}

auto fileName = "testIMT.root";
Expand Down
1 change: 1 addition & 0 deletions tree/dataframe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTDataFrame
ROOT/RDF/RJittedVariation.hxx
ROOT/RDF/RLazyDSImpl.hxx
ROOT/RDF/RLoopManager.hxx
ROOT/RDF/RMaskedEntryRange.hxx
ROOT/RDF/RMergeableValue.hxx
ROOT/RDF/RMetaData.hxx
ROOT/RDF/RNodeBase.hxx
Expand Down
3 changes: 2 additions & 1 deletion tree/dataframe/inc/ROOT/RCsvDS.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
namespace ROOT::Internal::RDF {
class R__CLING_PTRCHECK(off) RCsvDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
void *fValuePtr;
void *GetImpl(Long64_t) final { return fValuePtr; }
void *GetImpl(std::size_t) final { return fValuePtr; }
void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) final {}

public:
RCsvDSColumnReader(void *valuePtr) : fValuePtr(valuePtr) {}
Expand Down
23 changes: 13 additions & 10 deletions tree/dataframe/inc/ROOT/RDF/RAction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -99,31 +99,34 @@ public:
}

template <typename ColType>
auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType &
auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType &
{
if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(entry))
if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(idx))
return *val;

throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() +
") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " +
std::to_string(entry) +
std::to_string(idx) +
". You can use the DefaultValueFor operation to provide a default value, or "
"FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
}

template <typename... ColTypes, std::size_t... S>
void CallExec(unsigned int slot, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>)
void CallExec(unsigned int slot, std::size_t idx, TypeList<ColTypes...>, std::index_sequence<S...>)
{
ROOT::Internal::RDF::CallGuaranteedOrder{[&](auto &&...args) { return fHelper.Exec(slot, args...); },
GetValueChecked<ColTypes>(slot, S, entry)...};
(void)entry; // avoid unused parameter warning (gcc 12.1)
GetValueChecked<ColTypes>(slot, S, idx)...};
(void)idx; // avoid unused parameter warning (gcc 12.1)
}

void Run(unsigned int slot, Long64_t entry) final
void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
{
// check if entry passes all filters
if (fPrevNode.CheckFilters(slot, entry))
CallExec(slot, entry, ColumnTypes_t{}, TypeInd_t{});
const auto mask = fPrevNode.CheckFilters(slot, bulkBeginEntry, bulkSize);
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });

// Assume 1-size bulk for now
if (mask[0])
CallExec(slot, /*idx=*/0u, ColumnTypes_t{}, TypeInd_t{});
}

void TriggerChildrenCount() final { fPrevNode.IncrChildrenCount(); }
Expand Down
3 changes: 2 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/RActionBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public:
RColumnRegister &GetColRegister() { return fColRegister; }
RLoopManager *GetLoopManager() { return fLoopManager; }
unsigned int GetNSlots() const { return fNSlots; }
virtual void Run(unsigned int slot, Long64_t entry) = 0;
virtual void Initialize() = 0;
virtual void InitSlot(TTreeReader *r, unsigned int slot) = 0;
virtual void TriggerChildrenCount() = 0;
Expand Down Expand Up @@ -92,6 +91,8 @@ public:

virtual std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&results) = 0;
virtual std::unique_ptr<RActionBase> CloneAction(void *newResult) = 0;

virtual void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) = 0;
};
} // namespace RDF
} // namespace Internal
Expand Down
31 changes: 19 additions & 12 deletions tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -166,55 +166,62 @@ public:
fHelper.InitTask(r, slot);
}

void *GetValue(unsigned int slot, std::size_t readerIdx, Long64_t entry)
void *GetValue(unsigned int slot, std::size_t readerIdx, std::size_t idx)
{
assert(slot < fValues.size());
assert(readerIdx < fValues[slot].size());
if (auto *val = fValues[slot][readerIdx]->template TryGet<void>(entry))
if (auto *val = fValues[slot][readerIdx]->template TryGet<void>(idx))
return val;

throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() +
") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " +
std::to_string(entry) +
std::to_string(idx) +
". You can use the DefaultValueFor operation to provide a default value, or "
"FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
}

void CallExec(unsigned int slot, Long64_t entry)
void CallExec(unsigned int slot, std::size_t idx)
{
std::vector<void *> untypedValues;
auto nReaders = fValues[slot].size();
untypedValues.reserve(nReaders);
for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
untypedValues.push_back(GetValue(slot, readerIdx, entry));
untypedValues.push_back(GetValue(slot, readerIdx, idx));

fHelper.Exec(slot, untypedValues);
}

void Run(unsigned int slot, Long64_t entry) final
void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
{
if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
// check if entry passes all filters
std::vector<bool> filterPassed(fPrevNodes.size(), false);
std::vector<ROOT::Internal::RDF::RMaskedEntryRange> filterPassed(fPrevNodes.size(), 1ul);
for (unsigned int variation = 0; variation < fPrevNodes.size(); ++variation) {
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, entry);
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, bulkBeginEntry, bulkSize);
}

// Currently, every event where any of nominal or variations pass gets written to the output.
// This logic could be extended for different use cases if the need arises.
if (std::any_of(filterPassed.begin(), filterPassed.end(), [](bool val) { return val; })) {
if (std::any_of(filterPassed.begin(), filterPassed.end(),
[](const ROOT::Internal::RDF::RMaskedEntryRange &val) { return val[0]; })) {
// TODO: Don't allocate
std::vector<void *> untypedValues;
auto nReaders = fValues[slot].size();
untypedValues.reserve(nReaders);
std::for_each(fValues[slot].begin(), fValues[slot].end(), [bulkBeginEntry, bulkSize](auto *v) {
v->Load(
ROOT::Internal::RDF::RMaskedEntryRange{bulkSize, true, static_cast<std::uint64_t>(bulkBeginEntry)});
});
for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
untypedValues.push_back(GetValue(slot, readerIdx, entry));
untypedValues.push_back(GetValue(slot, readerIdx, /*idx=*/0u));

fHelper.Exec(slot, untypedValues, filterPassed);
}
} else {
if (fPrevNodes.front()->CheckFilters(slot, entry))
CallExec(slot, entry);
const auto mask = fPrevNodes.front()->CheckFilters(slot, bulkBeginEntry, bulkSize);
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
if (mask[0])
CallExec(slot, /*idx=*/0u);
}
}

Expand Down
14 changes: 11 additions & 3 deletions tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define ROOT_INTERNAL_RDF_RCOLUMNREADERBASE

#include <Rtypes.h>
#include <ROOT/RDF/RMaskedEntryRange.hxx>

namespace ROOT {
namespace Detail {
Expand All @@ -26,23 +27,30 @@ This pure virtual class provides a common base class for the different column re
RDSColumnReader.
**/
class R__CLING_PTRCHECK(off) RColumnReaderBase {

public:
virtual ~RColumnReaderBase() = default;

/// Load the column value for the given entry.
/// \param entry The entry number to load.
/// \param mask The entry mask. Values will be loaded only for entries for which the mask equals true.
void Load(const ROOT::Internal::RDF::RMaskedEntryRange &mask) { LoadImpl(mask); }

/// Return the column value for the given entry.
/// \tparam T The column type
/// \param entry The entry number
///
/// The caller is responsible for checking that the returned value actually
/// exists.
template <typename T>
T *TryGet(Long64_t entry)
T *TryGet(std::size_t entryInBulk)
{
return static_cast<T *>(GetImpl(entry));
return static_cast<T *>(GetImpl(entryInBulk));
}

private:
virtual void *GetImpl(Long64_t entry) = 0;
virtual void *GetImpl(std::size_t entryInBulk) = 0;
virtual void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) = 0;
};

} // namespace RDF
Expand Down
3 changes: 2 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ template <typename T>
class R__CLING_PTRCHECK(off) RDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
T **fDSValuePtr = nullptr;

void *GetImpl(Long64_t) final { return *fDSValuePtr; }
void *GetImpl(std::size_t) final { return *fDSValuePtr; }
void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) final {}

public:
RDSColumnReader(void *DSValuePtr) : fDSValuePtr(static_cast<T **>(DSValuePtr)) {}
Expand Down
41 changes: 28 additions & 13 deletions tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,22 @@ template <typename T>
class R__CLING_PTRCHECK(off) RDefaultValueFor final : public RDefineBase {
using ColumnTypes_t = ROOT::TypeTraits::TypeList<T>;
using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
// Avoid instantiating vector<bool> as `operator[]` returns temporaries in that case. Use std::deque instead.
using ValuesPerSlot_t = std::conditional_t<std::is_same<T, bool>::value, std::deque<T>, std::vector<T>>;

using ValuesPerSlot_t = std::vector<ROOT::RVec<T>>;

T fDefaultValue;
ValuesPerSlot_t fLastResults;
// Each slot accesses a cache of values for the current bulk
ValuesPerSlot_t fCachedResultsPerSlot;
// One column reader per slot
std::vector<RColumnReaderBase *> fValues;

/// Define objects corresponding to systematic variations other than nominal for this defined column.
/// The map key is the full variation name, e.g. "pt:up".
std::unordered_map<std::string, std::unique_ptr<RDefineBase>> fVariedDefines;

T &GetValueOrDefault(unsigned int slot, Long64_t entry)
T &GetValueOrDefault(unsigned int slot, std::size_t idx)
{
if (auto *value = fValues[slot]->template TryGet<T>(entry))
if (auto *value = fValues[slot]->template TryGet<T>(idx))
return *value;
else
return fDefaultValue;
Expand All @@ -71,12 +72,16 @@ public:
RLoopManager &lm, const std::string &variationName = "nominal")
: RDefineBase(name, type, colRegister, lm, columns, variationName),
fDefaultValue(defaultValue),
fLastResults(lm.GetNSlots() * RDFInternal::CacheLineStep<T>()),
fCachedResultsPerSlot(lm.GetNSlots() * RDFInternal::CacheLineStep<ROOT::RVec<T>>()),
fValues(lm.GetNSlots())
{
fLoopManager->Register(this);
// We suppress errors that TTreeReader prints regarding the missing branch
fLoopManager->InsertSuppressErrorsForMissingBranch(fColumnNames[0]);
// Assume 1-size bulk for now
for (decltype(lm.GetNSlots()) i = 0; i < lm.GetNSlots(); ++i) {
fCachedResultsPerSlot[i * RDFInternal::CacheLineStep<ROOT::RVec<T>>()].resize(1ul);
}
}

RDefaultValueFor(const RDefaultValueFor &) = delete;
Expand All @@ -97,20 +102,30 @@ public:
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = -1;
}

/// Return the (type-erased) address of the Define'd value for the given processing slot.
/// Return the beginning of the cached results of the current bulk for the input processing slot
void *GetValuePtr(unsigned int slot) final
{
return static_cast<void *>(&fLastResults[slot * RDFInternal::CacheLineStep<T>()]);
return static_cast<void *>(fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<T>>()].data());
}

/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
void Update(unsigned int slot, Long64_t entry) final
void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) final
{
if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()]) {
// evaluate this define expression, cache the result
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, entry);
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = entry;
if (static_cast<Long64_t>(mask.GetFirstEntry()) ==
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()])
return;

// Assume 1-size bulk for now
fValues[slot]->Load(mask);
const std::size_t bulkSize = fLoopManager->GetCurrentBulkSize();
auto &result = fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<T>>()];
result.clear();
result.resize(bulkSize);
for (std::size_t i = 0; i < bulkSize; ++i) {
if (mask[i])
fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<T>>()][i] = GetValueOrDefault(slot, i);
}
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
}

void Update(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo & /*id*/) final {}
Expand Down
Loading
Loading