Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First GPU implementation of DNNs. #199

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7101abb
Added GPU implementation of DNNs.
Jul 17, 2016
49906d9
Removed space before variable.
Jul 17, 2016
d4b6866
Removed Print() statement on weight matrices after training.
Jul 17, 2016
5690a9c
Removed output.
Jul 17, 2016
578e317
Removed explicit setting of the backend compiler.
Jul 18, 2016
a09f2fa
Added include guards for Cuda architecture header.
Jul 18, 2016
0146c55
Added missing test file.
Jul 18, 2016
f8d2317
Added missing file.
Jul 18, 2016
fb77aa8
Removed profiling switch in TestDerivativesCuda.
Jul 19, 2016
7f559f7
Fixed naming of Cuda kernels.
Jul 19, 2016
66587d0
Some optimizations in the training routine.
Jul 21, 2016
9d162bf
Applied stash.
Jul 21, 2016
ab09b0d
Fixed out of bounds memory access in vertical reduction kernel.
Jul 25, 2016
56d0579
Fixed out of bounds memory access in vertical reduction kernel.
Jul 25, 2016
92be232
Merge branch 'tmva_gpu' of https://github.com/simonpf/root into tmva_gpu
Jul 25, 2016
9eb6a50
Minor cosmetics.
Jul 26, 2016
c6ae8ed
Some more cosmetics.
Jul 26, 2016
ab85e89
Cleaned up output.
Jul 28, 2016
4a5822a
Merge branch 'tmva_gpu' of https://github.com/simonpf/root into tmva_gpu
Jul 28, 2016
6c8abba
Fixed minimization test.
Jul 28, 2016
e594dda
Fixed formatting in CudaMatrix.h
Jul 28, 2016
0c7667f
Enlarged batch size in minimization test.
Jul 28, 2016
f9b95e9
Merge branch 'tmva_gpu' of github.com:simonpf/root into tmva_gpu
Jul 28, 2016
46ac988
Generic data loader.
Aug 9, 2016
8e1edd2
Added TestDataLoaderCuda.cxx.
Aug 9, 2016
dfe0f09
Made copy async.
Aug 9, 2016
b9528e5
Smaller fixes.
Aug 9, 2016
0a3ae51
Added flop counter.
Aug 9, 2016
7cef87b
Merge branch 'tmva_gpu' of github.com:simonpf/root into tmva_gpu
Aug 9, 2016
3cb26ba
Fixed flop rate computation.
Aug 9, 2016
0b23d06
Testing different curand initialization.
Aug 11, 2016
df80c5e
Testing different parallelization scheme.
Aug 11, 2016
dcbf1c6
Minor fixes and modifications.
Aug 13, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions tmva/tmva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ set(headers4 TNeuron.h TSynapse.h TActivationChooser.h TActivation.h TActivation
TNeuronInputSqSum.h TNeuronInputAbs.h Types.h Ranking.h RuleFit.h RuleFitAPI.h IMethod.h MsgLogger.h
VariableTransformBase.h VariableIdentityTransform.h VariableDecorrTransform.h VariablePCATransform.h
VariableGaussTransform.h VariableNormalizeTransform.h VariableRearrangeTransform.h VariableTransform.h ROCCalc.h ROCCurve.h)
# Sources of the architecture-independent reference implementation of the
# low-level DNN interface (always compiled).
set(dnn_files src/DNN/Architectures/Reference.cxx
              src/DNN/Architectures/Reference/DataLoader.cxx)
# CUDA device-code sources (.cu) implementing the GPU version of the
# low-level DNN interface; compiled into a separate library only when a
# CUDA installation is found (see the CUDA handling section below).
set(dnn_cuda_files src/DNN/Architectures/Cuda/ActivationFunctions.cu
                   src/DNN/Architectures/Cuda/Arithmetic.cu
                   src/DNN/Architectures/Cuda/Buffers.cxx
                   src/DNN/Architectures/Cuda/CudaMatrix.cu
                   src/DNN/Architectures/Cuda/DataLoader.cu
                   src/DNN/Architectures/Cuda/Dropout.cu
                   src/DNN/Architectures/Cuda/Initialization.cu
                   src/DNN/Architectures/Cuda/Kernels.cu
                   src/DNN/Architectures/Cuda/LossFunctions.cu
                   src/DNN/Architectures/Cuda/OutputFunctions.cu
                   src/DNN/Architectures/Cuda/Propagation.cu
                   src/DNN/Architectures/Cuda/Regularization.cu)

#---Need to suffix each header name by TMVA/ -----------------
foreach(hs headers1 headers2 headers3 headers4)
Expand All @@ -45,7 +59,18 @@ endforeach()

ROOT_GENERATE_DICTIONARY(G__TMVA ${theaders1} ${theaders2} ${theaders3} ${theaders4} MODULE TMVA LINKDEF LinkDef.h OPTIONS "-writeEmptyRootPCM")

ROOT_LINKER_LIBRARY(TMVA *.cxx G__TMVA.cxx LIBRARIES Core
#---Handle CUDA dependent code. -----------------
# If a CUDA installation is found, compile the .cu sources into a separate
# dnn_cuda library, define the DNNCUDA preprocessor symbol so the C++ code
# can enable its GPU paths, and link against cuBLAS. Otherwise cuda_libraries
# stays empty and TMVA is built CPU-only.
find_package(CUDA)
if (CUDA_FOUND)
   cuda_add_library(dnn_cuda ${dnn_cuda_files})
   # NOTE(review): appending to CMAKE_CXX_FLAGS affects every target configured
   # after this point, not just TMVA — confirm this global effect is intended.
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDNNCUDA")
   set(cuda_libraries dnn_cuda ${CUDA_CUBLAS_LIBRARIES})
else (CUDA_FOUND)
   set(cuda_libraries)
endif(CUDA_FOUND)

# Build the TMVA library from the generic sources, the dictionary, and the
# reference DNN sources, linking the (possibly empty) CUDA libraries.
root_linker_library(TMVA *.cxx G__TMVA.cxx ${dnn_files}
                    LIBRARIES Core ${cuda_libraries}
                    DEPENDENCIES RIO Hist Tree TreePlayer MLP Minuit XMLIO)

install(DIRECTORY inc/TMVA/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/TMVA
Expand All @@ -63,7 +88,7 @@ if(NOT gnuinstall)
PATTERN "data" EXCLUDE)
endif()

#ROOT_ADD_TEST_SUBDIRECTORY(test)
ROOT_ADD_TEST_SUBDIRECTORY(test/DNN)



Expand Down
296 changes: 296 additions & 0 deletions tmva/tmva/inc/TMVA/DNN/Architectures/Cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Simon Pfreundschuh 05/07/16

/*************************************************************************
* Copyright (C) 2016, Simon Pfreundschuh *
* All rights reserved. *
* *
* For the licensing terms see $ROOTSYS/LICENSE. *
* For the list of contributors see $ROOTSYS/README/CREDITS. *
*************************************************************************/

///////////////////////////////////////////////////////////////
// Definition of the TCuda architecture, which provides an //
// implementation of the low-level functionality for neural //
// networks for the CUDA computing architectures. //
///////////////////////////////////////////////////////////////

#ifndef TMVA_DNN_ARCHITECTURES_CUDA
#define TMVA_DNN_ARCHITECTURES_CUDA

#include <utility>

#include "cuda.h"

#include "Cuda/Types.h"
#include "Cuda/Kernels.h"
#include "Cuda/Buffers.h"
#include "Cuda/DataLoader.h"
#include "Cuda/CudaMatrix.h"
#include "TMVA/DNN/DataLoader.h"

namespace TMVA
{
namespace DNN
{

/** The TCuda architecture class.
*
* Low-level interface class for CUDA computing architecture. Contains as
* public types the declaration of the scalar, matrix and data loader types
* for this architecture as well as the remaining functions in the low-level
* interface in the form of static members.
*/
class TCuda
{

public:

   // Public type aliases defining the scalar, matrix, buffer and data loader
   // types used by this architecture throughout the generic DNN code.
   using Scalar_t = CudaDouble_t;
   using Matrix_t = TCudaMatrix;
   using DeviceBuffer_t = TCudaDeviceBuffer;
   using HostBuffer_t = TCudaHostBuffer;
   template <typename Data_t>
   using DataLoader_t = TCudaDataLoader<Data_t>;

   //____________________________________________________________________________
   //
   // Propagation
   //____________________________________________________________________________

   /** @name Forward Propagation
    * Low-level functions required for the forward propagation of activations
    * through the network.
    */
   ///@{
   /** Matrix-multiply \p input with the transpose of \p weights and
    * write the results into \p output. */
   static void MultiplyTranspose(TCudaMatrix &output,
                                 const TCudaMatrix &input,
                                 const TCudaMatrix &weights);
   /** Add the vector \p biases row-wise to the matrix \p output. */
   static void AddRowWise(TCudaMatrix &output,
                          const TCudaMatrix &biases);
   ///@}

   /** @name Backward Propagation
    * Low-level functions required for the backward propagation of gradients
    * through the network.
    */
   ///@{
   /** Perform the complete backward propagation step. If the provided
    * \p activationGradientsBackward matrix is not empty, compute the
    * gradients of the objective function with respect to the activations
    * of the previous layer (backward direction).
    * Also compute the weight and the bias gradients. Modifies the values
    * in \p df and thus produces only a valid result, if it is applied the
    * first time after the corresponding forward propagation has been per-
    * formed. */
   static void Backward(TCudaMatrix & activationGradientsBackward,
                        TCudaMatrix & weightGradients,
                        TCudaMatrix & biasGradients,
                        TCudaMatrix & df,
                        const TCudaMatrix & activationGradients,
                        const TCudaMatrix & weights,
                        const TCudaMatrix & activationBackward);
   /** Adds the elements in matrix \p B scaled by \p beta to the elements in
    * the matrix \p A. This is required for the weight update in the gradient
    * descent step.*/
   static void ScaleAdd(TCudaMatrix & A,
                        const TCudaMatrix & B,
                        Scalar_t beta = 1.0);

   /** Copy matrix elements; following the output-first argument convention
    * used throughout this interface, \p A is copied into \p B. */
   static void Copy(TCudaMatrix & B,
                    const TCudaMatrix & A);
   ///@}

   //____________________________________________________________________________
   //
   // Activation Functions
   //____________________________________________________________________________

   /** @name Activation Functions
    * For each activation function, the low-level interface contains two routines.
    * One that applies the activation function to a matrix and one that evaluates
    * the derivatives of the activation function at the elements of a given matrix
    * and writes the results into the result matrix.
    */
   ///@{
   static void Identity(TCudaMatrix & B);
   static void IdentityDerivative(TCudaMatrix & B,
                                  const TCudaMatrix & A);

   static void Relu(TCudaMatrix & B);
   static void ReluDerivative(TCudaMatrix & B,
                              const TCudaMatrix & A);

   static void Sigmoid(TCudaMatrix & B);
   static void SigmoidDerivative(TCudaMatrix & B,
                                 const TCudaMatrix & A);

   static void Tanh(TCudaMatrix & B);
   static void TanhDerivative(TCudaMatrix & B,
                              const TCudaMatrix & A);

   static void SymmetricRelu(TCudaMatrix & B);
   static void SymmetricReluDerivative(TCudaMatrix & B,
                                       const TCudaMatrix & A);

   static void SoftSign(TCudaMatrix & B);
   static void SoftSignDerivative(TCudaMatrix & B,
                                  const TCudaMatrix & A);

   static void Gauss(TCudaMatrix & B);
   static void GaussDerivative(TCudaMatrix & B,
                               const TCudaMatrix & A);
   ///@}

   //____________________________________________________________________________
   //
   // Loss Functions
   //____________________________________________________________________________

   /** @name Loss Functions
    * Loss functions compute a scalar value given the \p output of the network
    * for a given training input and the expected network prediction \p Y that
    * quantifies the quality of the prediction. For each function also a routine
    * that computes the gradients (suffixed by Gradients) must be provided for
    * the starting of the backpropagation algorithm.
    */
   ///@{

   static CudaDouble_t MeanSquaredError(const TCudaMatrix &Y,
                                        const TCudaMatrix &output);
   static void MeanSquaredErrorGradients(TCudaMatrix & dY,
                                         const TCudaMatrix &Y,
                                         const TCudaMatrix &output);

   /** Sigmoid transformation is implicitly applied, thus \p output should
    * hold the linear activations of the last layer in the net. */
   static CudaDouble_t CrossEntropy(const TCudaMatrix &Y,
                                    const TCudaMatrix &output);

   static void CrossEntropyGradients(TCudaMatrix & dY,
                                     const TCudaMatrix & Y,
                                     const TCudaMatrix & output);
   ///@}

   //____________________________________________________________________________
   //
   // Output Functions
   //____________________________________________________________________________

   /** @name Output Functions
    * Output functions transform the activations \p output of the
    * output layer in the network to a valid prediction \p YHat for
    * the desired usage of the network, e.g. the identity function
    * for regression or the sigmoid transformation for two-class
    * classification.
    */
   ///@{
   static void Sigmoid(TCudaMatrix &YHat,
                       const TCudaMatrix & );
   ///@}

   //____________________________________________________________________________
   //
   // Regularization
   //____________________________________________________________________________

   /** @name Regularization
    * For each regularization type two functions are required, one named
    * <tt><Type>Regularization</tt> that evaluates the corresponding
    * regularization functional for a given weight matrix and the
    * <tt>Add<Type>RegularizationGradients</tt>, that adds the regularization
    * component in the gradients to the provided matrix.
    */
   ///@{

   static CudaDouble_t L1Regularization(const TCudaMatrix & W);
   static void AddL1RegularizationGradients(TCudaMatrix & A,
                                            const TCudaMatrix & W,
                                            CudaDouble_t weightDecay);

   static CudaDouble_t L2Regularization(const TCudaMatrix & W);
   static void AddL2RegularizationGradients(TCudaMatrix & A,
                                            const TCudaMatrix & W,
                                            CudaDouble_t weightDecay);
   ///@}

   //____________________________________________________________________________
   //
   // Initialization
   //____________________________________________________________________________

   /** @name Initialization
    * For each initialization method, one function in the low-level interface
    * is provided. The naming scheme is <tt>Initialize<Type></tt> for a given
    * initialization method Type.
    */
   ///@{

   static void InitializeGauss(TCudaMatrix & A);
   static void InitializeUniform(TCudaMatrix & A);
   static void InitializeIdentity(TCudaMatrix & A);
   static void InitializeZero(TCudaMatrix & A);

   ///@}

   //____________________________________________________________________________
   //
   // Dropout
   //____________________________________________________________________________

   /** @name Dropout
    */
   ///@{

   /** Apply dropout with activation probability \p p to the given
    * matrix \p A and scale the result by reciprocal of \p p. */
   static void Dropout(TCudaMatrix & A, CudaDouble_t p);

   ///@}

   //____________________________________________________________________________
   //
   // Additional Arithmetic Functions
   //____________________________________________________________________________

   /** @name Additional Arithmetic Functions
    *
    * Additional arithmetic on CUDA matrices used to implement the low-level
    * interface.
    */
   ///@{

   /** Standard multiplication of two matrices \p A and \p B with the result being
    * written into \p C.
    */
   static void Multiply(TCudaMatrix &C,
                        const TCudaMatrix &A,
                        const TCudaMatrix &B);
   /** Matrix multiplication involving a transposed operand, with the result
    * written into \p output.
    * NOTE(review): the original comment described this as A * B^T, but the
    * parameters are named \p input and \p Weights — confirm the operand
    * orientation against the implementation. */
   static void TransposeMultiply(TCudaMatrix &output,
                                 const TCudaMatrix &input,
                                 const TCudaMatrix &Weights);
   /** In-place Hadamard (element-wise) product of matrices \p A and \p B
    * with the result being written into \p A.
    */
   static void Hadamard(TCudaMatrix &A,
                        const TCudaMatrix &B);

   /** Sum columns of the (m x n) matrix \p A and write the results into the
    * first m elements of \p B (\p A is const; \p B is the output matrix).
    */
   static void SumColumns(TCudaMatrix &B, const TCudaMatrix &A);

   /** Compute the sum of all elements in \p A */
   static CudaDouble_t Sum(const TCudaMatrix &A);
};

} // namespace DNN
} // namespace TMVA

#endif
Loading