diff --git a/tmva/tmva/inc/TMVA/DNN/Functions.h b/tmva/tmva/inc/TMVA/DNN/Functions.h index b62b56ed05de6..b6456f76d7077 100644 --- a/tmva/tmva/inc/TMVA/DNN/Functions.h +++ b/tmva/tmva/inc/TMVA/DNN/Functions.h @@ -76,6 +76,11 @@ enum class EInitialization { kGlorotUniform = 'F', }; +/// Enum representing the optimizer used for training. +enum class EOptimizer { + kSGD = 0, +}; + //______________________________________________________________________________ // // Activation Functions diff --git a/tmva/tmva/inc/TMVA/DNN/Optimizer.h b/tmva/tmva/inc/TMVA/DNN/Optimizer.h new file mode 100644 index 0000000000000..5db8ae9a843ee --- /dev/null +++ b/tmva/tmva/inc/TMVA/DNN/Optimizer.h @@ -0,0 +1,114 @@ +// @(#)root/tmva/tmva/dnn:$Id$ +// Author: Ravi Kiran S + +/********************************************************************************** + * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * + * Package: TMVA * + * Class : VOptimizer * + * Web : http://tmva.sourceforge.net * + * * + * Description: * + * General Optimizer Class * + * * + * Authors (alphabetical): * + * Ravi Kiran S - CERN, Switzerland * + * * + * Copyright (c) 2005-2018 : * + * CERN, Switzerland * + * U. of Victoria, Canada * + * MPI-K Heidelberg, Germany * + * U. of Bonn, Germany * + * * + * Redistribution and use in source and binary forms, with or without * + * modification, are permitted according to the terms listed in LICENSE * + * (http://tmva.sourceforge.net/LICENSE) * + **********************************************************************************/ + +#ifndef TMVA_DNN_OPTIMIZER +#define TMVA_DNN_OPTIMIZER + +#include "TMVA/DNN/GeneralLayer.h" +#include "TMVA/DNN/DeepNet.h" + +namespace TMVA { +namespace DNN { + +/** \class VOptimizer + Generic Optimizer class + + This class represents the general class for all optimizers in the Deep Learning + Module. + */ +template , + typename DeepNet_t = TDeepNet> +class VOptimizer { +public: + using Matrix_t = typename Architecture_t::Matrix_t; + using Scalar_t = typename Architecture_t::Scalar_t; + +protected: + Scalar_t fLearningRate; ///< The learning rate used for training. + size_t fGlobalStep; ///< The current global step count during training. + DeepNet_t &fDeepNet; ///< The reference to the deep net. + + /*! Update the weights, given the current weight gradients. */ + virtual void + UpdateWeights(size_t layerIndex, std::vector &weights, const std::vector &weightGradients) = 0; + + /*! Update the biases, given the current bias gradients. */ + virtual void + UpdateBiases(size_t layerIndex, std::vector &biases, const std::vector &biasGradients) = 0; + +public: + /*! Constructor. */ + VOptimizer(Scalar_t learningRate, DeepNet_t &deepNet); + + /*! Performs one step of optimization. */ + void Step(); + + /*! Virtual Destructor. */ + virtual ~VOptimizer() = default; + + /*! Increments the global step. */ + void IncrementGlobalStep() { this->fGlobalStep++; } + + /*! Getters */ + Scalar_t GetLearningRate() const { return fLearningRate; } + size_t GetGlobalStep() const { return fGlobalStep; } + std::vector &GetLayers() { return fDeepNet.GetLayers(); } + Layer_t *GetLayerAt(size_t i) { return fDeepNet.GetLayerAt(i); } + + /*! 
Setters */ + void SetLearningRate(size_t learningRate) { fLearningRate = learningRate; } +}; + +// +// +// The General Optimizer Class - Implementation +//_________________________________________________________________________________________________ +template +VOptimizer::VOptimizer(Scalar_t learningRate, DeepNet_t &deepNet) + : fLearningRate(learningRate), fGlobalStep(0), fDeepNet(deepNet) +{ +} + +// //_________________________________________________________________________________________________ +// template +// VOptimizer::~VOptimizer() +// { +// } + +//_________________________________________________________________________________________________ +template +auto VOptimizer::Step() -> void +{ + for (size_t i = 0; i < this->GetLayers().size(); i++) { + this->UpdateWeights(i, this->GetLayerAt(i)->GetWeights(), this->GetLayerAt(i)->GetWeightGradients()); + this->UpdateBiases(i, this->GetLayerAt(i)->GetBiases(), this->GetLayerAt(i)->GetBiasGradients()); + } +} + +} // namespace DNN +} // namespace TMVA + +#endif \ No newline at end of file diff --git a/tmva/tmva/inc/TMVA/DNN/SGD.h b/tmva/tmva/inc/TMVA/DNN/SGD.h new file mode 100644 index 0000000000000..9e2e56f933301 --- /dev/null +++ b/tmva/tmva/inc/TMVA/DNN/SGD.h @@ -0,0 +1,174 @@ +// @(#)root/tmva/tmva/dnn:$Id$ +// Author: Ravi Kiran S + +/********************************************************************************** + * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * + * Package: TMVA * + * Class : TSGD * + * Web : http://tmva.sourceforge.net * + * * + * Description: * + * Stochastic Batch Gradient Descent Optimizer Class * + * * + * Authors (alphabetical): * + * Ravi Kiran S - CERN, Switzerland * + * * + * Copyright (c) 2005-2018: * + * CERN, Switzerland * + * U. of Victoria, Canada * + * MPI-K Heidelberg, Germany * + * U. of Bonn, Germany * + * * + * Redistribution and use in source and binary forms, with or without * + * modification, are permitted according to the terms listed in LICENSE * + * (http://tmva.sourceforge.net/LICENSE) * + **********************************************************************************/ + +#ifndef TMVA_DNN_SGD +#define TMVA_DNN_SGD + +#include "TMatrix.h" +#include "TMVA/DNN/Optimizer.h" +#include "TMVA/DNN/Functions.h" + +namespace TMVA { +namespace DNN { + +/** \class TSGD + Stochastic Batch Gradient Descent Optimizer class + + This class represents the Stochastic Batch Gradient Descent Optimizer with options for applying momentum + and nesterov momentum. + */ +template , + typename DeepNet_t = TDeepNet> +class TSGD : public VOptimizer { +public: + using Matrix_t = typename Architecture_t::Matrix_t; + using Scalar_t = typename Architecture_t::Scalar_t; + +protected: + Scalar_t fMomentum; ///< The momentum used for training. + std::vector> + fPastWeightGradients; ///< The sum of the past weight gradients associated with the deep net. + std::vector> + fPastBiasGradients; ///< The sum of the past bias gradients associated with the deep net. + + /*! Update the weights, given the current weight gradients. */ + void UpdateWeights(size_t layerIndex, std::vector &weights, const std::vector &weightGradients); + + /*! Update the biases, given the current bias gradients. */ + void UpdateBiases(size_t layerIndex, std::vector &biases, const std::vector &biasGradients); + +public: + /*! Constructor. */ + TSGD(Scalar_t learningRate, DeepNet_t &deepNet, Scalar_t momentum); + + /*! Destructor. */ + ~TSGD() = default; + + /*! 
Getters */ + Scalar_t GetMomentum() const { return fMomentum; } + + std::vector> &GetPastWeightGradients() { return fPastWeightGradients; } + std::vector &GetPastWeightGradientsAt(size_t i) { return fPastWeightGradients[i]; } + + std::vector> &GetPastBiasGradients() { return fPastBiasGradients; } + std::vector &GetPastBiasGradientsAt(size_t i) { return fPastBiasGradients[i]; } +}; + +// +// +// The Stochastic Gradient Descent Optimizer Class - Implementation +//_________________________________________________________________________________________________ +template +TSGD::TSGD(Scalar_t learningRate, DeepNet_t &deepNet, Scalar_t momentum) + : VOptimizer(learningRate, deepNet), fMomentum(momentum) +{ + std::vector &layers = deepNet.GetLayers(); + size_t layersNSlices = layers.size(); + fPastWeightGradients.resize(layersNSlices); + fPastBiasGradients.resize(layersNSlices); + + for (size_t i = 0; i < layersNSlices; i++) { + size_t weightsNSlices = (layers[i]->GetWeights()).size(); + + for (size_t j = 0; j < weightsNSlices; j++) { + Matrix_t ¤tWeights = layers[i]->GetWeightsAt(j); + size_t weightsNRows = currentWeights.GetNrows(); + size_t weightsNCols = currentWeights.GetNcols(); + + fPastWeightGradients[i].emplace_back(weightsNRows, weightsNCols); + initialize(fPastWeightGradients[i][j], EInitialization::kZero); + } + + size_t biasesNSlices = (layers[i]->GetBiases()).size(); + + for (size_t j = 0; j < biasesNSlices; j++) { + Matrix_t ¤tBiases = layers[i]->GetBiasesAt(j); + size_t biasesNRows = currentBiases.GetNrows(); + size_t biasesNCols = currentBiases.GetNcols(); + + fPastBiasGradients[i].emplace_back(biasesNRows, biasesNCols); + initialize(fPastBiasGradients[i][j], EInitialization::kZero); + } + } +} + +// //_________________________________________________________________________________________________ +// template +// TSGD::~TSGD() +// { +// } + +//_________________________________________________________________________________________________ +template +auto TSGD::UpdateWeights(size_t layerIndex, std::vector &weights, + const std::vector &weightGradients) -> void +{ + // accumulating the current layer past weight gradients to include the current weight gradients. + // Vt = momentum * Vt-1 + currentGradients + std::vector ¤tLayerPastWeightGradients = this->GetPastWeightGradientsAt(layerIndex); + for (size_t k = 0; k < currentLayerPastWeightGradients.size(); k++) { + Matrix_t accumulation(currentLayerPastWeightGradients[k].GetNrows(), + currentLayerPastWeightGradients[k].GetNcols()); + initialize(accumulation, EInitialization::kZero); + Architecture_t::ScaleAdd(accumulation, currentLayerPastWeightGradients[k], this->GetMomentum()); + Architecture_t::ScaleAdd(accumulation, weightGradients[k], 1.0); + Architecture_t::Copy(currentLayerPastWeightGradients[k], accumulation); + } + + // updating the weights. + // theta = theta - learningRate * Vt + for (size_t i = 0; i < weights.size(); i++) { + Architecture_t::ScaleAdd(weights[i], currentLayerPastWeightGradients[i], -this->GetLearningRate()); + } +} + +//_________________________________________________________________________________________________ +template +auto TSGD::UpdateBiases(size_t layerIndex, std::vector &biases, + const std::vector &biasGradients) -> void +{ + // accumulating the current layer past bias gradients to include the current bias gradients. 
+ // Vt = momentum * Vt-1 + currentGradients + std::vector ¤tLayerPastBiasGradients = this->GetPastBiasGradientsAt(layerIndex); + for (size_t k = 0; k < currentLayerPastBiasGradients.size(); k++) { + Matrix_t accumulation(currentLayerPastBiasGradients[k].GetNrows(), currentLayerPastBiasGradients[k].GetNcols()); + initialize(accumulation, EInitialization::kZero); + Architecture_t::ScaleAdd(accumulation, currentLayerPastBiasGradients[k], this->GetMomentum()); + Architecture_t::ScaleAdd(accumulation, biasGradients[k], 1.0); + Architecture_t::Copy(currentLayerPastBiasGradients[k], accumulation); + } + + // updating the biases + // theta = theta - learningRate * Vt + for (size_t i = 0; i < biases.size(); i++) { + Architecture_t::ScaleAdd(biases[i], currentLayerPastBiasGradients[i], -this->GetLearningRate()); + } +} + +} // namespace DNN +} // namespace TMVA + +#endif \ No newline at end of file diff --git a/tmva/tmva/inc/TMVA/MethodDL.h b/tmva/tmva/inc/TMVA/MethodDL.h index 770233101eadd..ce523f3e8e2f6 100644 --- a/tmva/tmva/inc/TMVA/MethodDL.h +++ b/tmva/tmva/inc/TMVA/MethodDL.h @@ -65,6 +65,7 @@ struct TTrainingSettings { size_t convergenceSteps; size_t maxEpochs; DNN::ERegularization regularization; + DNN::EOptimizer optimizer; Double_t learningRate; Double_t momentum; Double_t weightDecay; diff --git a/tmva/tmva/src/MethodDL.cxx b/tmva/tmva/src/MethodDL.cxx index 2bcfb57faa773..16be729605267 100644 --- a/tmva/tmva/src/MethodDL.cxx +++ b/tmva/tmva/src/MethodDL.cxx @@ -1,5 +1,5 @@ - // @(#)root/tmva/tmva/cnn:$Id$Ndl -// Author: Vladimir Ilievski, Saurav Shekhar +// @(#)root/tmva/tmva/cnn:$Id$Ndl +// Author: Vladimir Ilievski, Saurav Shekhar, Ravi Kiran S /********************************************************************************** * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * @@ -13,6 +13,7 @@ * Authors (alphabetical): * * Vladimir Ilievski - CERN, Switzerland * * Saurav Shekhar - ETH Zurich, Switzerland * + * Ravi Kiran S - CERN, Switzerland * * * * Copyright (c) 2005-2015: * * CERN, Switzerland * @@ -38,6 +39,7 @@ #include "TMVA/DNN/TensorDataLoader.h" #include "TMVA/DNN/Functions.h" #include "TMVA/DNN/DLMinimizers.h" +#include "TMVA/DNN/SGD.h" #include "TStopwatch.h" #include @@ -52,6 +54,7 @@ using TMVA::DNN::EActivationFunction; using TMVA::DNN::ELossFunction; using TMVA::DNN::EInitialization; using TMVA::DNN::EOutputFunction; +using TMVA::DNN::EOptimizer; namespace TMVA { @@ -342,6 +345,15 @@ void MethodDL::ProcessOptions() settings.regularization = DNN::ERegularization::kL2; } + TString optimizer = fetchValueTmp(block, "Optimizer", TString("SGD")); + if (optimizer == "SGD") { + settings.optimizer = DNN::EOptimizer::kSGD; + } else { + // Since only one optimizer is implemented, make that as default choice for now if the input string is + // incorrect. 
+ settings.optimizer = DNN::EOptimizer::kSGD; + } + TString strMultithreading = fetchValueTmp(block, "Multithreading", TString("True")); if (strMultithreading.BeginsWith("T")) { @@ -967,7 +979,8 @@ void MethodDL::TrainDeepNet() { using Scalar_t = typename Architecture_t::Scalar_t; - using DeepNet_t = TMVA::DNN::TDeepNet; + using Layer_t = TMVA::DNN::VGeneralLayer; + using DeepNet_t = TMVA::DNN::TDeepNet; using TensorDataLoader_t = TTensorDataLoader; bool debug = Log().GetMinType() == kDEBUG; @@ -1005,6 +1018,7 @@ void MethodDL::TrainDeepNet() ELossFunction J = this->GetLossFunction(); EInitialization I = this->GetWeightInitialization(); ERegularization R = settings.regularization; + EOptimizer O = settings.optimizer; Scalar_t weightDecay = settings.weightDecay; //Batch size should be included in batch layout as well. There are two possibilities: @@ -1092,16 +1106,24 @@ void MethodDL::TrainDeepNet() deepNet.GetBatchDepth(), deepNet.GetBatchHeight(), deepNet.GetBatchWidth(), deepNet.GetOutputWidth(), nThreads); - // Initialize the minimizer - DNN::TDLGradientDescent minimizer(settings.learningRate, settings.convergenceSteps, - settings.testInterval); + // create a pointer to base class VOptimizer + std::unique_ptr> optimizer; + + // initialize the base class pointer with the corresponding derived class object. + switch (O) { + + // Intentional fall-through + case EOptimizer::kSGD: + optimizer = std::unique_ptr>( + new DNN::TSGD(settings.learningRate, deepNet, settings.momentum)); + break; + } // Initialize the vector of batches, one batch for one slave network std::vector> batches{}; bool converged = false; - // count the steps until the convergence - size_t stepCount = 0; + size_t convergenceCount = 0; size_t batchesInEpoch = nTrainingSamples / deepNet.GetBatchSize(); // start measuring @@ -1139,7 +1161,7 @@ void MethodDL::TrainDeepNet() Double_t minTestError = 0; while (!converged) { - stepCount++; + optimizer->IncrementGlobalStep(); trainingData.Shuffle(rng); // execute all epochs @@ -1155,25 +1177,14 @@ void MethodDL::TrainDeepNet() auto my_batch = trainingData.GetTensorBatch(); - - - - // execute one minimization step - // StepMomentum is currently not written for single thread, TODO write it - if (settings.momentum > 0.0) { - //minimizer.StepMomentum(deepNet, nets, batches, settings.momentum); - minimizer.Step(deepNet, my_batch.GetInput(), my_batch.GetOutput(), my_batch.GetWeights()); - } else { - //minimizer.Step(deepNet, nets, batches); - minimizer.Step(deepNet, my_batch.GetInput(), my_batch.GetOutput(), my_batch.GetWeights()); - } - - + // execute one optimization step + deepNet.Forward(my_batch.GetInput(), true); + deepNet.Backward(my_batch.GetInput(), my_batch.GetOutput(), my_batch.GetWeights()); + optimizer->Step(); } //} - - if ((stepCount % minimizer.GetTestInterval()) == 0) { + if ((optimizer->GetGlobalStep() % settings.testInterval) == 0) { std::chrono::time_point t1,t2; @@ -1191,10 +1202,19 @@ void MethodDL::TrainDeepNet() t2 = std::chrono::system_clock::now(); testError /= (Double_t)(nTestSamples / settings.batchSize); + + // checking for convergence + if (testError < minTestError) { + convergenceCount = 0; + } else { + convergenceCount += settings.testInterval; + } + // copy configuration when reached a minimum error if (testError < minTestError ) { // Copy weights from deepNet to fNet - Log() << std::setw(10) << stepCount << " Minimun Test error found - save the configuration " << Endl; + Log() << std::setw(10) << optimizer->GetGlobalStep() + << " Minimum Test error found 
- save the configuration " << Endl; for (size_t i = 0; i < deepNet.GetDepth(); ++i) { const auto & nLayer = fNet->GetLayerAt(i); const auto & dLayer = deepNet.GetLayerAt(i); @@ -1239,13 +1259,12 @@ void MethodDL::TrainDeepNet() // nGFlops *= deepnet.GetNFlops() * 1e-9; double eventTime = elapsed1.count()/( batchesInEpoch * settings.testInterval * settings.batchSize); - converged = minimizer.HasConverged(testError) || stepCount >= settings.maxEpochs; + converged = + convergenceCount > settings.convergenceSteps || optimizer->GetGlobalStep() >= settings.maxEpochs; - Log() << std::setw(10) << stepCount << " | " << std::setw(12) << trainingError << std::setw(12) << testError - << std::setw(12) << seconds/settings.testInterval - << std::setw(12) << elapsed_testing.count() - << std::setw(12) << 1./eventTime - << std::setw(12) << minimizer.GetConvergenceCount() + Log() << std::setw(10) << optimizer->GetGlobalStep() << " | " << std::setw(12) << trainingError + << std::setw(12) << testError << std::setw(12) << seconds / settings.testInterval << std::setw(12) + << elapsed_testing.count() << std::setw(12) << 1. / eventTime << std::setw(12) << convergenceCount << Endl; if (converged) { @@ -1254,9 +1273,10 @@ void MethodDL::TrainDeepNet() tstart = std::chrono::system_clock::now(); } - //if (stepCount % 10 == 0 || converged) { - if (converged && debug) { - Log() << "Final Deep Net Weights for phase " << trainingPhase << " epoch " << stepCount << Endl; + // if (stepCount % 10 == 0 || converged) { + if (converged && debug) { + Log() << "Final Deep Net Weights for phase " << trainingPhase << " epoch " << optimizer->GetGlobalStep() + << Endl; auto & weights_tensor = deepNet.GetLayerAt(0)->GetWeights(); auto & bias_tensor = deepNet.GetLayerAt(0)->GetBiases(); for (size_t l = 0; l < weights_tensor.size(); ++l) @@ -1264,12 +1284,10 @@ void MethodDL::TrainDeepNet() bias_tensor[0].Print(); } - } trainingPhase++; } // end loop on training Phase - } //////////////////////////////////////////////////////////////////////////////// diff --git a/tmva/tmva/test/DNN/CMakeLists.txt b/tmva/tmva/test/DNN/CMakeLists.txt index 1a77ba11d8de5..8466b3069d21a 100644 --- a/tmva/tmva/test/DNN/CMakeLists.txt +++ b/tmva/tmva/test/DNN/CMakeLists.txt @@ -103,6 +103,11 @@ if ( (BLAS_FOUND OR mathmore) AND imt AND tmva-cpu) LIBRARIES ${Libraries}) ROOT_ADD_TEST(TMVA-DNN-Minimization-Cpu COMMAND testMinimizationCpu) + # DNN - Optimization CPU + ROOT_EXECUTABLE(testOptimizationCpu TestOptimizationCpu.cxx + LIBRARIES ${Libraries}) + ROOT_ADD_TEST(TMVA-DNN-Optimization-Cpu COMMAND testOptimizationCpu) + endif () # DNN - Activation Functions diff --git a/tmva/tmva/test/DNN/CNN/TestMethodDL.h b/tmva/tmva/test/DNN/CNN/TestMethodDL.h index 1cac008458e8b..f69e670f806fd 100644 --- a/tmva/tmva/test/DNN/CNN/TestMethodDL.h +++ b/tmva/tmva/test/DNN/CNN/TestMethodDL.h @@ -105,15 +105,15 @@ void testMethodDL_CNN(TString architectureStr) "DENSE|2|LINEAR"); // Training strategies. 
- TString training0("LearningRate=1e-1,Momentum=0.9,Repetitions=1," + TString training0("LearningRate=1e-1,Optimizer=SGD,Momentum=0.9,Repetitions=1," "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10," "WeightDecay=1e-4,Regularization=L2," "DropConfig=0.0+0.5+0.5+0.5, Multithreading=True"); - TString training1("LearningRate=1e-2,Momentum=0.9,Repetitions=1," + TString training1("LearningRate=1e-2,Optimizer=SGD,Momentum=0.9,Repetitions=1," "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10," "WeightDecay=1e-4,Regularization=L2," "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True"); - TString training2("LearningRate=1e-3,Momentum=0.0,Repetitions=1," + TString training2("LearningRate=1e-3,Optimizer=SGD,Momentum=0.0,Repetitions=1," "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10," "WeightDecay=1e-4,Regularization=L2," "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True"); @@ -171,8 +171,7 @@ void testMethodDL_DNN(TString architectureStr) TMVA::Config::Instance(); TFile *input(0); - // TString fname = "/Users/vladimirilievski/Desktop/Vladimir/GSoC/ROOT-CI/common-version/root/tmva/tmva/test/DNN/CNN/" - // "dataset/tmva_class_example.root"; + // TString fname = "tmva_class_example.root"; TString fname = "http://root.cern.ch/files/tmva_class_example.root"; TString fopt = "CACHEREAD"; input = TFile::Open(fname,fopt); @@ -219,15 +218,15 @@ void testMethodDL_DNN(TString architectureStr) TString layoutString("Layout=RESHAPE|1|1|4|FLAT,DENSE|128|TANH,DENSE|128|TANH,DENSE|128|TANH,DENSE|1|LINEAR"); // Training strategies. - TString training0("LearningRate=1e-1,Momentum=0.9,Repetitions=1," + TString training0("LearningRate=1e-2,Optimizer=SGD,Momentum=0.9,Repetitions=1," "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10," "WeightDecay=1e-4,Regularization=L2," "DropConfig=0.0+0.5+0.5+0.5, Multithreading=True"); - TString training1("LearningRate=1e-2,Momentum=0.9,Repetitions=1," + TString training1("LearningRate=1e-2,Optimizer=SGD,Momentum=0.9,Repetitions=1," "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10," "WeightDecay=1e-4,Regularization=L2," "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True"); - TString training2("LearningRate=1e-3,Momentum=0.9,Repetitions=1," + TString training2("LearningRate=1e-3,Optimizer=SGD,Momentum=0.9,Repetitions=1," "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10," "WeightDecay=1e-4,Regularization=L2," "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True"); diff --git a/tmva/tmva/test/DNN/TestOptimization.h b/tmva/tmva/test/DNN/TestOptimization.h new file mode 100644 index 0000000000000..9be8343c6a761 --- /dev/null +++ b/tmva/tmva/test/DNN/TestOptimization.h @@ -0,0 +1,301 @@ +// @(#)root/tmva/tmva/dnn:$Id$ +// Author: Ravi Kiran S + +/********************************************************************************** + * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * + * Package: TMVA * + * Class : * + * Web : http://tmva.sourceforge.net * + * * + * Description: * + * Testing Stochastic Batch Gradient Descent Optimizer * + * * + * Authors (alphabetical): * + * Ravi Kiran S - CERN, Switzerland * + * * + * Copyright (c) 2005-2018: * + * CERN, Switzerland * + * U. of Victoria, Canada * + * MPI-K Heidelberg, Germany * + * U. 
of Bonn, Germany * + * * + * Redistribution and use in source and binary forms, with or without * + * modification, are permitted according to the terms listed in LICENSE * + * (http://tmva.sourceforge.net/LICENSE) * + **********************************************************************************/ + +#ifndef TMVA_TEST_DNN_TEST_OPTIMIZATION_H +#define TMVA_TEST_DNN_TEST_OPTIMIZATION_H + +#include "Utility.h" +#include "TMath.h" +#include "TRandom3.h" +#include "TStopwatch.h" +#include "TFormula.h" +#include "TString.h" + +#include "TMVA/Configurable.h" +#include "TMVA/Tools.h" +#include "TMVA/Types.h" +#include "TMVA/IMethod.h" +#include "TMVA/ClassifierFactory.h" + +#include "TMVA/DNN/SGD.h" +#include "TMVA/DNN/TensorDataLoader.h" + +#include +#include +#include + +using namespace TMVA::DNN; + +/** Train a linear neural network on a randomly generated linear mapping + * from an 8-dimensional input space to a 1-dimensional output space. + * Returns the error of the response of the network to the input containing + * only ones to the 1x8 matrix used to generate the training data. + */ +template +auto testOptimizationSGD(typename Architecture_t::Scalar_t momentum, Bool_t debug) -> typename Architecture_t::Scalar_t +{ + using Matrix_t = typename Architecture_t::Matrix_t; + using Layer_t = VGeneralLayer; + using DeepNet_t = TDeepNet; + using Optimizer_t = TSGD; + using DataLoader_t = TTensorDataLoader; + + size_t nSamples = 256; + size_t nFeatures = 64; + size_t batchSize = 64; + + // Initialize train and test input + // XTrain = (1 x nSamples x nFeatures) + // XTest = (1 x nSamples x nFeatures) + std::vector> XTrain, XTest; + + XTrain.reserve(1); + XTest.reserve(1); + XTrain.emplace_back(nSamples, nFeatures); + XTest.emplace_back(nSamples, nFeatures); + + // Initialize train and test output + // YTrain = (nSamples x nOutput) + // YTest = (nSamples x nOutput) + size_t nOutput = 1; + TMatrixT YTrain(nSamples, nOutput), YTest(nSamples, nOutput); + + // Initialize train and test weights + // WTrain = (nSamples x 1) + // WTest = (nSamples x 1) + TMatrixT WTrain(nSamples, 1), WTest(nSamples, 1); + + // Initialize K + // K = (nFeatures x nOutput) + TMatrixT K(nFeatures, nOutput); + + // Use random K to generate linear mapping + randomMatrix(K); + + randomMatrix(XTrain[0]); + randomMatrix(XTest[0]); + + // Generate the output + // YTrain = XTrain[0] * K + YTrain.Mult(XTrain[0], K); + YTest.Mult(XTest[0], K); + + // Fill-in the batch weights + fillMatrix(WTrain, 1.0); + fillMatrix(WTest, 1.0); + + // Construct the deepNet + size_t inputDepth = 1; + size_t inputHeight = 1; + size_t inputWidth = nFeatures; + + size_t batchDepth = 1; + size_t batchHeight = batchSize; + size_t batchWidth = nFeatures; + + DeepNet_t deepNet(batchSize, inputDepth, inputHeight, inputWidth, batchDepth, batchHeight, batchWidth, + ELossFunction::kMeanSquaredError, EInitialization::kGauss, ERegularization::kNone, 0.0, true); + deepNet.AddDenseLayer(64, EActivationFunction::kIdentity); + deepNet.AddDenseLayer(64, EActivationFunction::kIdentity); + deepNet.AddDenseLayer(1, EActivationFunction::kIdentity); + deepNet.Initialize(); + + if (debug) { + deepNet.Print(); + } + + // Initialize the tensor inputs + size_t nThreads = 1; + TensorInput trainingInput(XTrain, YTrain, WTrain); + TensorInput testInput(XTest, YTest, WTest); + + DataLoader_t trainingData(trainingInput, nSamples, batchSize, batchDepth, batchHeight, batchWidth, nOutput, + nThreads); + DataLoader_t testingData(testInput, nSamples, batchSize, batchDepth, batchHeight, 
batchWidth, nOutput, nThreads); + + // Initialize the optimizer + Optimizer_t optimizer(0.001, deepNet, momentum); + + // Initialize the variables related to training procedure + bool converged = false; + size_t testInterval = 1; + size_t maxEpochs = 500; + size_t batchesInEpoch = nSamples / deepNet.GetBatchSize(); + size_t convergenceCount = 0; + size_t convergenceSteps = 10; + + if (debug) { + std::string separator(62, '-'); + std::cout << separator << std::endl; + std::cout << std::setw(10) << "Epoch" + << " | " << std::setw(12) << "Train Err." << std::setw(12) << "Test Err." << std::setw(12) + << "t(s)/epoch" << std::setw(12) << "Eval t(s)" << std::setw(12) << "nEvents/s" << std::setw(12) + << "Conv. Steps" << std::endl; + std::cout << separator << std::endl; + } + + // start measuring + std::chrono::time_point tstart, tend; + tstart = std::chrono::system_clock::now(); + + size_t shuffleSeed = 0; + TMVA::RandomGenerator rng(shuffleSeed); + + Double_t minTestError = 0; + while (!converged) { + optimizer.IncrementGlobalStep(); + trainingData.Shuffle(rng); + + // training process + for (size_t i = 0; i < batchesInEpoch; i++) { + auto my_batch = trainingData.GetTensorBatch(); + deepNet.Forward(my_batch.GetInput(), true); + deepNet.Backward(my_batch.GetInput(), my_batch.GetOutput(), my_batch.GetWeights()); + optimizer.Step(); + } + + // calculating the error + if ((optimizer.GetGlobalStep() % testInterval) == 0) { + std::chrono::time_point t1, t2; + t1 = std::chrono::system_clock::now(); + + // compute test error + Double_t testError = 0.0; + for (auto batch : testingData) { + auto inputTensor = batch.GetInput(); + auto outputMatrix = batch.GetOutput(); + auto weights = batch.GetWeights(); + testError += deepNet.Loss(inputTensor, outputMatrix, weights); + } + testError /= (Double_t)(nSamples / batchSize); + + t2 = std::chrono::system_clock::now(); + + // checking for convergence + if (testError < minTestError) { + convergenceCount = 0; + } else { + convergenceCount += testInterval; + } + + // found the minimum test error + if (testError < minTestError) { + if (debug) { + std::cout << std::setw(10) << optimizer.GetGlobalStep() << " Minimum Test error found : " << testError + << std::endl; + } + minTestError = testError; + } else if (minTestError <= 0.0) + minTestError = testError; + + // compute training error + Double_t trainingError = 0.0; + for (auto batch : trainingData) { + auto inputTensor = batch.GetInput(); + auto outputMatrix = batch.GetOutput(); + auto weights = batch.GetWeights(); + trainingError += deepNet.Loss(inputTensor, outputMatrix, weights); + } + trainingError /= (Double_t)(nSamples / batchSize); + + // stop measuring + tend = std::chrono::system_clock::now(); + + // compute numerical throughput + std::chrono::duration elapsed_seconds = tend - tstart; + std::chrono::duration elapsed1 = t1 - tstart; + + // time to compute training and test errors + std::chrono::duration elapsed_testing = tend - t1; + + double seconds = elapsed_seconds.count(); + double eventTime = elapsed1.count() / (batchesInEpoch * testInterval * batchSize); + + converged = optimizer.GetGlobalStep() >= maxEpochs || convergenceCount > convergenceSteps; + + if (debug) { + std::cout << std::setw(10) << optimizer.GetGlobalStep() << " | " << std::setw(12) << trainingError + << std::setw(12) << testError << std::setw(12) << seconds / testInterval << std::setw(12) + << elapsed_testing.count() << std::setw(12) << 1. 
/ eventTime << std::setw(12) << convergenceCount + << std::endl; + + if (converged) { + std::cout << std::endl; + } + } + + tstart = std::chrono::system_clock::now(); + } + + if (converged && debug) { + std::cout << "Final Deep Net Weights for epoch " << optimizer.GetGlobalStep() << std::endl; + auto &weights_tensor = deepNet.GetLayerAt(0)->GetWeights(); + auto &bias_tensor = deepNet.GetLayerAt(0)->GetBiases(); + for (size_t l = 0; l < weights_tensor.size(); l++) + weights_tensor[l].Print(); + bias_tensor[0].Print(); + } + } + + // test the net + // Logic : Y = X * K + // Let X = I, Then Y = I * K => Y = K + // I = (1 x batchSize x nFeatures) + std::vector I; + I.reserve(1); + I.emplace_back(batchSize, nFeatures); + for (size_t i = 0; i < batchSize; i++) { + I[0](i, i) = 1.0; + } + + deepNet.Forward(I, false); + + // get the output of the last layer of the deepNet + TMatrixT Y(deepNet.GetLayerAt(deepNet.GetLayers().size() - 1)->GetOutputAt(0)); + + if (debug) { + std::cout << "\nY:\n"; + + for (auto i = 0; i < Y.GetNrows(); i++) { + for (auto j = 0; j < Y.GetNcols(); j++) { + std::cout << Y(i, j) << " "; + } + std::cout << std::endl; + } + + std::cout << "\nK:\n"; + for (auto i = 0; i < K.GetNrows(); i++) { + for (auto j = 0; j < K.GetNcols(); j++) { + std::cout << K(i, j) << " "; + } + std::cout << std::endl; + } + } + + return maximumRelativeError(Y, K); +} + +#endif diff --git a/tmva/tmva/test/DNN/TestOptimizationCpu.cxx b/tmva/tmva/test/DNN/TestOptimizationCpu.cxx new file mode 100644 index 0000000000000..089c06dd0c807 --- /dev/null +++ b/tmva/tmva/test/DNN/TestOptimizationCpu.cxx @@ -0,0 +1,69 @@ +// @(#)root/tmva/tmva/dnn:$Id$ +// Author: Ravi Kiran S + +/********************************************************************************** + * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * + * Package: TMVA * + * Class : * + * Web : http://tmva.sourceforge.net * + * * + * Description: * + * Testing Stochastic Batch Gradient Descent Optimizer for Cpu Backend * + * * + * Authors (alphabetical): * + * Ravi Kiran S - CERN, Switzerland * + * * + * Copyright (c) 2005-2018: * + * CERN, Switzerland * + * U. of Victoria, Canada * + * MPI-K Heidelberg, Germany * + * U. 
of Bonn, Germany * + * * + * Redistribution and use in source and binary forms, with or without * + * modification, are permitted according to the terms listed in LICENSE * + * (http://tmva.sourceforge.net/LICENSE) * + **********************************************************************************/ + +#include "TestOptimization.h" +#include "TMVA/DNN/Architectures/Cpu.h" + +#include + +using namespace TMVA::DNN; + +int main() +{ + std::cout << "Testing optimization: (single precision)" << std::endl; + + Real_t momentumSinglePrecision = 0.0; + Double_t error = testOptimizationSGD>(momentumSinglePrecision, false); + std::cout << "Stochastic Gradient Descent: Maximum relative error = " << error << std::endl; + if (error > 1e-3) { + return 1; + } + + momentumSinglePrecision = 0.9; + error = testOptimizationSGD>(momentumSinglePrecision, false); + std::cout << "Stochastic Gradient Descent with momentum: Maximum relative error = " << error << std::endl; + if (error > 1e-3) { + return 1; + } + + std::cout << std::endl << "Testing optimization: (double precision)" << std::endl; + + Double_t momentumDoublePrecision = 0.0; + error = testOptimizationSGD>(momentumDoublePrecision, false); + std::cout << "Stochastic Gradient Descent: Maximum relative error = " << error << std::endl; + if (error > 1e-5) { + return 1; + } + + momentumDoublePrecision = 0.9; + error = testOptimizationSGD>(momentumDoublePrecision, false); + std::cout << "Stochastic Gradient Descent with momentum: Maximum relative error = " << error << std::endl; + if (error > 1e-5) { + return 1; + } + + return 0; +} \ No newline at end of file
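
Note on the update rule: the core of this patch is the heavy-ball momentum update implemented by TSGD::UpdateWeights and TSGD::UpdateBiases above. The accumulated past gradients are scaled by the momentum, the current gradients are added (Vt = momentum * Vt-1 + g), and the parameters then move against the accumulator scaled by the learning rate (theta = theta - learningRate * Vt). The standalone sketch below shows the same arithmetic on plain std::vector<double> buffers instead of Architecture_t::Matrix_t; the helper name updateWithMomentum and the numeric values are illustrative only and are not part of the patch.

// Minimal standalone sketch of the momentum update used by TSGD (illustrative names).
//   V_t   = momentum * V_{t-1} + g_t          accumulate past gradients
//   theta = theta - learningRate * V_t        move parameters against the accumulator
#include <cstddef>
#include <iostream>
#include <vector>

void updateWithMomentum(std::vector<double> &theta, std::vector<double> &pastGrad,
                        const std::vector<double> &grad, double learningRate, double momentum)
{
   for (std::size_t i = 0; i < theta.size(); ++i) {
      pastGrad[i] = momentum * pastGrad[i] + grad[i]; // V_t = momentum * V_{t-1} + g_t
      theta[i] -= learningRate * pastGrad[i];         // theta = theta - learningRate * V_t
   }
}

int main()
{
   std::vector<double> theta = {1.0, -2.0};   // parameters (weights or biases)
   std::vector<double> pastGrad = {0.0, 0.0}; // initialized to zero, as in the TSGD constructor
   std::vector<double> grad = {0.5, -0.25};   // current gradients for this step

   updateWithMomentum(theta, pastGrad, grad, /*learningRate=*/0.1, /*momentum=*/0.9);
   std::cout << theta[0] << " " << theta[1] << "\n"; // prints 0.95 -1.975
   return 0;
}

With momentum set to 0.0 the accumulator equals the current gradient and the update reduces to plain stochastic gradient descent, which is why TestOptimizationCpu.cxx exercises both momentum = 0.0 and momentum = 0.9.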
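
Note on convergence bookkeeping: MethodDL::TrainDeepNet no longer relies on the minimizer's HasConverged; instead it keeps a convergenceCount that is reset whenever the test error improves and grows by testInterval otherwise, and training stops once the count exceeds convergenceSteps or the optimizer's global step reaches maxEpochs. TestOptimization.h repeats the same logic. A minimal sketch of that bookkeeping, with a hypothetical placeholder standing in for the measured test error:

// Sketch of the convergence bookkeeping used in the training loops of this patch.
#include <cstddef>
#include <iostream>

int main()
{
   const std::size_t testInterval = 1, convergenceSteps = 10, maxEpochs = 500;
   std::size_t convergenceCount = 0, globalStep = 0;
   double minTestError = 0.0;
   bool converged = false;

   while (!converged) {
      ++globalStep;                        // optimizer->IncrementGlobalStep()
      double testError = 1.0 / globalStep; // placeholder for the measured test error

      if (testError < minTestError) {
         convergenceCount = 0;             // error improved: reset the counter
         minTestError = testError;
      } else {
         convergenceCount += testInterval; // no improvement during this test interval
         if (minTestError <= 0.0)
            minTestError = testError;      // first measurement initializes the minimum
      }
      converged = convergenceCount > convergenceSteps || globalStep >= maxEpochs;
   }
   std::cout << "stopped after " << globalStep << " steps\n";
   return 0;
}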
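
Note on usage: on the user side the optimizer is selected per training phase through the TrainingStrategy option string, as the updated strings in TestMethodDL.h show. A representative value, mirroring those tests rather than introducing any new option:

// Training-strategy string selecting the new SGD optimizer with momentum.
// Any unrecognized Optimizer value currently falls back to SGD in ProcessOptions().
#include "TString.h"

TString training0("LearningRate=1e-2,Optimizer=SGD,Momentum=0.9,Repetitions=1,"
                  "ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
                  "WeightDecay=1e-4,Regularization=L2,"
                  "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True");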