diff --git a/include/nbla/cuda/function/dropout.hpp b/include/nbla/cuda/function/dropout.hpp
index 84f8e5a4c..cebb37e08 100644
--- a/include/nbla/cuda/function/dropout.hpp
+++ b/include/nbla/cuda/function/dropout.hpp
@@ -28,9 +28,8 @@ template <typename T> class DropoutCuda : public Dropout<T> {
 public:
   typedef typename CudaType<T>::type Tc;
-  explicit DropoutCuda(const Context &ctx, double p, int seed = -1,
-                       bool output_mask = false)
-      : Dropout<T>(ctx, T(p), seed, output_mask) {
+  explicit DropoutCuda(const Context &ctx, double p, int seed = -1)
+      : Dropout<T>(ctx, T(p), seed) {
     cuda_set_device(std::stoi(ctx.device_id));
     NBLA_CHECK(this->p_ >= 0., error_code::value,
                "p must be between 0.0 and 1.0");
@@ -55,11 +54,14 @@
 protected:
   curandGenerator_t curand_generator_;
+  bool store_mask_for_recompute_ = false;
   virtual void setup_impl(const Variables &inputs, const Variables &outputs);
   virtual void forward_impl(const Variables &inputs, const Variables &outputs);
   virtual void backward_impl(const Variables &inputs, const Variables &outputs,
                              const vector<bool> &propagate_down,
                              const vector<bool> &accum);
+  virtual void setup_recompute_impl(const Variables &inputs,
+                                    const Variables &outputs);
   virtual void recompute_impl(const Variables &inputs,
                               const Variables &outputs);
diff --git a/src/nbla/cuda/function/generic/dropout.cu b/src/nbla/cuda/function/generic/dropout.cu
index d9337c6eb..7f51cc3f0 100644
--- a/src/nbla/cuda/function/generic/dropout.cu
+++ b/src/nbla/cuda/function/generic/dropout.cu
@@ -31,6 +31,16 @@ __global__ void kernel_dropout_forward(const int size, const float scale,
   }
 }
 
+template <typename T>
+__global__ void kernel_dropout_recompute(const int size, const float scale,
+                                         const float p, const T *x, T *y,
+                                         const float *m) {
+  NBLA_CUDA_KERNEL_LOOP(s, size) {
+    // This operation was already done in forward: m[s] = (m[s] > p) ? 1 : 0;
+    y[s] = x[s] * m[s] * scale;
+  }
+}
+
 template <typename T, bool accum>
 __global__ void kernel_dropout_backward(const int size, const float scale,
                                         const T *dy, const float *m, T *dx) {
@@ -42,12 +52,13 @@
 template <typename T>
 void DropoutCuda<T>::setup_impl(const Variables &inputs,
                                 const Variables &outputs) {
-  outputs[0]->reshape(inputs[0]->shape(), true);
-  if (this->output_mask_) {
-    outputs[1]->reshape(inputs[0]->shape(), true);
-  } else {
-    this->mask_.reshape(inputs[0]->shape(), true);
-  }
+  Dropout<T>::setup_impl(inputs, outputs);
+}
+
+template <typename T>
+void DropoutCuda<T>::setup_recompute_impl(const Variables &inputs,
+                                          const Variables &outputs) {
+  store_mask_for_recompute_ = true;
 }
 
 template <typename T>
@@ -56,8 +67,8 @@ void DropoutCuda<T>::forward_impl(const Variables &inputs,
                                   const Variables &outputs) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
   Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true);
-  Variable &mask = this->output_mask_ ? *outputs[1] : this->mask_;
-  float *m = mask.cast_data_and_get_pointer<float>(this->ctx_, true);
+  VariablePtr mask = this->mask_;
+  float *m = mask->cast_data_and_get_pointer<float>(this->ctx_, true);
   curandGenerator_t &gen =
       this->seed_ == -1 ? SingletonManager::get<Cuda>()->curand_generator()
                         : curand_generator_;
@@ -69,12 +80,16 @@
 template <typename T>
 void DropoutCuda<T>::recompute_impl(const Variables &inputs,
                                     const Variables &outputs) {
+  NBLA_CHECK(this->mask_->data()->array()->get_num_arrays(),
+             error_code::unclassified,
+             "The mask of Dropout must be stored in mask_ for recomputation. "
+             "Please report this error to the NNabla developer team.");
   cuda_set_device(std::stoi(this->ctx_.device_id));
   const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
   Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true);
-  Variable &mask = this->output_mask_ ? *outputs[1] : this->mask_;
-  float *m = mask.cast_data_and_get_pointer<float>(this->ctx_, true);
-  NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_dropout_forward, inputs[0]->size(),
+  VariablePtr mask = this->mask_;
+  const float *m = mask->get_data_pointer<float>(this->ctx_);
+  NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_dropout_recompute, inputs[0]->size(),
                                  this->scale_, this->p_, x, y, m);
 }
 
@@ -89,8 +104,8 @@ void DropoutCuda<T>::backward_impl(const Variables &inputs,
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Tc *dx = inputs[0]->cast_grad_and_get_pointer<Tc>(this->ctx_, !accum[0]);
   const Tc *dy = outputs[0]->get_grad_pointer<Tc>(this->ctx_);
-  Variable &mask = this->output_mask_ ? *outputs[1] : this->mask_;
-  const float *m = mask.get_data_pointer<float>(this->ctx_);
+  VariablePtr mask = this->mask_;
+  const float *m = mask->get_data_pointer<float>(this->ctx_);
   if (accum[0]) {
     NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_dropout_backward<Tc, true>),
                                    inputs[0]->size(), this->scale_, dy, m, dx);
@@ -98,5 +113,7 @@
     NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_dropout_backward<Tc, false>),
                                    inputs[0]->size(), this->scale_, dy, m, dx);
   }
+
+  this->clear_buffer();
 }
 }