Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public:
fChunkLoader =
std::make_unique<RChunkLoader<Args...>>(f_rdf, fNumEntries, fEntries, fChunkSize, fBlockSize, fValidationSplit,
fCols, vecSizes, vecPadding, fShuffle, fSetSeed);
fBatchLoader = std::make_unique<RBatchLoader>(fChunkSize, fBatchSize, fNumChunkCols);
fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols);

// split the dataset into training and validation sets
fChunkLoader->SplitDataset();
Expand Down
38 changes: 12 additions & 26 deletions tmva/tmva/inc/TMVA/BatchGenerator/RBatchLoader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,21 @@
#include "TMVA/RTensor.hxx"
#include "TMVA/Tools.h"

namespace TMVA {
namespace Experimental {
namespace Internal {
namespace TMVA::Experimental::Internal {

// clang-format off
/**
\class ROOT::TMVA::Experimental::Internal::RBatchLoader
\ingroup tmva
\brief Building and loading the batches from loaded chunks in RChunkLoader
In this class the chunks that are loaded into memory (see RChunkLoader) are split into batches used in the ML training which are loaded into a queue. This is done for both the training and validation chunks separatly.
In this class the chunks that are loaded into memory (see RChunkLoader) are split into batches used in the ML training
which are loaded into a queue. This is done for both the training and validation chunks separately.
*/

class RBatchLoader {
private:
// clang-format on
std::size_t fChunkSize;
std::size_t fBatchSize;
std::size_t fNumColumns;
std::size_t fMaxBatches;
std::size_t fTrainingRemainderRow = 0;
std::size_t fValidationRemainderRow = 0;

bool fIsActive = false;

Expand All @@ -63,7 +56,7 @@ private:
std::size_t fNumTrainingBatchQueue;
std::size_t fNumValidationBatchQueue;

// current batch that is loaded into memeory
// current batch that is loaded into memory
std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;

// primary and secondary batches used to create batches from a chunk
Expand All @@ -74,8 +67,7 @@ private:
std::unique_ptr<TMVA::Experimental::RTensor<float>> fSecondaryLeftoverValidationBatch;

public:
RBatchLoader(std::size_t chunkSize, std::size_t batchSize, std::size_t numColumns)
: fChunkSize(chunkSize), fBatchSize(batchSize), fNumColumns(numColumns)
RBatchLoader(std::size_t batchSize, std::size_t numColumns) : fBatchSize(batchSize), fNumColumns(numColumns)
{

fPrimaryLeftoverTrainingBatch =
Expand All @@ -95,9 +87,6 @@ public:
public:
void Activate()
{
// fTrainingRemainderRow = 0;
// fValidationRemainderRow = 0;

{
std::lock_guard<std::mutex> lock(fBatchLock);
fIsActive = true;
Expand Down Expand Up @@ -132,7 +121,6 @@ public:
return batch;
}


/// \brief Loading the training batch from the queue
/// \return Training batch
TMVA::Experimental::RTensor<float> GetTrainBatch()
Expand Down Expand Up @@ -220,7 +208,7 @@ public:

// copy LeftoverBatch to both fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverTrainingBatch
else if (emptySlots < LeftoverBatchSize) {
// copy the first part of LeftoverBatch to end of fPrimaryLeftoverTrainingBatch
// copy the first part of LeftoverBatch to end of fPrimaryLeftoverTrainingBatch
(*fPrimaryLeftoverTrainingBatch) = (*fPrimaryLeftoverTrainingBatch).Resize({fBatchSize, fNumColumns});
std::copy(LeftoverBatch.GetData(), LeftoverBatch.GetData() + (emptySlots * fNumColumns),
fPrimaryLeftoverTrainingBatch->GetData() + (PrimaryLeftoverSize * fNumColumns));
Expand All @@ -231,18 +219,18 @@ public:
std::copy(LeftoverBatch.GetData() + (emptySlots * fNumColumns),
LeftoverBatch.GetData() + (LeftoverBatchSize * fNumColumns),
fSecondaryLeftoverTrainingBatch->GetData());

// add fPrimaryLeftoverTrainingBatch to the batch vector
auto copy =
std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fBatchSize, fNumColumns});
std::copy(fPrimaryLeftoverTrainingBatch->GetData(),
fPrimaryLeftoverTrainingBatch->GetData() + (fBatchSize * fNumColumns), copy->GetData());
batches.emplace_back(std::move(copy));

// exchange fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverValidationBatch
*fPrimaryLeftoverTrainingBatch = *fSecondaryLeftoverTrainingBatch;
// restet fSecondaryLeftoverValidationBatch

// reset fSecondaryLeftoverValidationBatch
fSecondaryLeftoverValidationBatch =
std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{0, fNumColumns});
}
Expand All @@ -269,7 +257,7 @@ public:
fTrainingBatchQueue.push(std::move(batches[i]));
}
}

/// \brief Creating the validation batches from a chunk and adding them to the queue
/// \param[in] chunkTensor RTensor with the data from the chunk
/// \param[in] lastbatch Check if the batch in the chunk is the last one
Expand Down Expand Up @@ -359,8 +347,6 @@ public:
std::size_t GetNumValidationBatchQueue() { return fValidationBatchQueue.size(); }
};

} // namespace Internal
} // namespace Experimental
} // namespace TMVA
} // namespace TMVA::Experimental::Internal

#endif // TMVA_RBATCHLOADER
Loading