diff --git a/tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx b/tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx index 0b0e017169872..831b2c0e163d2 100644 --- a/tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx +++ b/tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx @@ -144,7 +144,7 @@ public: fChunkLoader = std::make_unique>(f_rdf, fNumEntries, fEntries, fChunkSize, fBlockSize, fValidationSplit, fCols, vecSizes, vecPadding, fShuffle, fSetSeed); - fBatchLoader = std::make_unique(fChunkSize, fBatchSize, fNumChunkCols); + fBatchLoader = std::make_unique(fBatchSize, fNumChunkCols); // split the dataset into training and validation sets fChunkLoader->SplitDataset(); diff --git a/tmva/tmva/inc/TMVA/BatchGenerator/RBatchLoader.hxx b/tmva/tmva/inc/TMVA/BatchGenerator/RBatchLoader.hxx index 012473ec712b0..3e78fd6904385 100644 --- a/tmva/tmva/inc/TMVA/BatchGenerator/RBatchLoader.hxx +++ b/tmva/tmva/inc/TMVA/BatchGenerator/RBatchLoader.hxx @@ -27,28 +27,21 @@ #include "TMVA/RTensor.hxx" #include "TMVA/Tools.h" -namespace TMVA { -namespace Experimental { -namespace Internal { +namespace TMVA::Experimental::Internal { -// clang-format off /** \class ROOT::TMVA::Experimental::Internal::RBatchLoader \ingroup tmva \brief Building and loading the batches from loaded chunks in RChunkLoader -In this class the chunks that are loaded into memory (see RChunkLoader) are split into batches used in the ML training which are loaded into a queue. This is done for both the training and validation chunks separatly. +In this class the chunks that are loaded into memory (see RChunkLoader) are split into batches used in the ML training +which are loaded into a queue. This is done for both the training and validation chunks separately. */ class RBatchLoader { private: - // clang-format on - std::size_t fChunkSize; std::size_t fBatchSize; std::size_t fNumColumns; - std::size_t fMaxBatches; - std::size_t fTrainingRemainderRow = 0; - std::size_t fValidationRemainderRow = 0; bool fIsActive = false; @@ -63,7 +56,7 @@ private: std::size_t fNumTrainingBatchQueue; std::size_t fNumValidationBatchQueue; - // current batch that is loaded into memeory + // current batch that is loaded into memory std::unique_ptr> fCurrentBatch; // primary and secondary batches used to create batches from a chunk @@ -74,8 +67,7 @@ private: std::unique_ptr> fSecondaryLeftoverValidationBatch; public: - RBatchLoader(std::size_t chunkSize, std::size_t batchSize, std::size_t numColumns) - : fChunkSize(chunkSize), fBatchSize(batchSize), fNumColumns(numColumns) + RBatchLoader(std::size_t batchSize, std::size_t numColumns) : fBatchSize(batchSize), fNumColumns(numColumns) { fPrimaryLeftoverTrainingBatch = @@ -95,9 +87,6 @@ public: public: void Activate() { - // fTrainingRemainderRow = 0; - // fValidationRemainderRow = 0; - { std::lock_guard lock(fBatchLock); fIsActive = true; @@ -132,7 +121,6 @@ public: return batch; } - /// \brief Loading the training batch from the queue /// \return Training batch TMVA::Experimental::RTensor GetTrainBatch() @@ -220,7 +208,7 @@ public: // copy LeftoverBatch to both fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverTrainingBatch else if (emptySlots < LeftoverBatchSize) { - // copy the first part of LeftoverBatch to end of fPrimaryLeftoverTrainingBatch + // copy the first part of LeftoverBatch to end of fPrimaryLeftoverTrainingBatch (*fPrimaryLeftoverTrainingBatch) = (*fPrimaryLeftoverTrainingBatch).Resize({fBatchSize, fNumColumns}); std::copy(LeftoverBatch.GetData(), LeftoverBatch.GetData() + (emptySlots * fNumColumns), fPrimaryLeftoverTrainingBatch->GetData() + (PrimaryLeftoverSize * fNumColumns)); @@ -231,18 +219,18 @@ public: std::copy(LeftoverBatch.GetData() + (emptySlots * fNumColumns), LeftoverBatch.GetData() + (LeftoverBatchSize * fNumColumns), fSecondaryLeftoverTrainingBatch->GetData()); - + // add fPrimaryLeftoverTrainingBatch to the batch vector auto copy = std::make_unique>(std::vector{fBatchSize, fNumColumns}); std::copy(fPrimaryLeftoverTrainingBatch->GetData(), fPrimaryLeftoverTrainingBatch->GetData() + (fBatchSize * fNumColumns), copy->GetData()); batches.emplace_back(std::move(copy)); - + // exchange fPrimaryLeftoverTrainingBatch and fSecondaryLeftoverValidationBatch *fPrimaryLeftoverTrainingBatch = *fSecondaryLeftoverTrainingBatch; - - // restet fSecondaryLeftoverValidationBatch + + // reset fSecondaryLeftoverValidationBatch fSecondaryLeftoverValidationBatch = std::make_unique>(std::vector{0, fNumColumns}); } @@ -269,7 +257,7 @@ public: fTrainingBatchQueue.push(std::move(batches[i])); } } - + /// \brief Creating the validation batches from a chunk and adding them to the queue /// \param[in] chunkTensor RTensor with the data from the chunk /// \param[in] lastbatch Check if the batch in the chunk is the last one @@ -359,8 +347,6 @@ public: std::size_t GetNumValidationBatchQueue() { return fValidationBatchQueue.size(); } }; -} // namespace Internal -} // namespace Experimental -} // namespace TMVA +} // namespace TMVA::Experimental::Internal #endif // TMVA_RBATCHLOADER