Bugfix: forward(clear_no_need_grad=True) and nn.grad of Dropout #313

Merged
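The title names the failing combination; the sketch below, using the NNabla Python frontend, illustrates the kind of graph it refers to. The network, shapes, and dropout rate are illustrative assumptions, not taken from this PR.

```python
# Illustrative sketch only (assumed repro, not part of the PR):
# nn.grad over a graph containing Dropout, with forward(clear_no_need_grad=True).
import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable.from_numpy_array(np.random.randn(4, 8).astype(np.float32))
x.need_grad = True
h = F.dropout(x, p=0.5)   # Dropout's backward needs the random mask it drew
y = F.sum(h * h)

# clear_no_need_grad=True frees buffers that backward is declared not to need;
# the grad_depends_* overrides in this PR make those declarations accurate.
y.forward(clear_no_need_grad=True)

# nn.grad builds a gradient graph; for Dropout this relies on the mask being
# visible to the graph rather than hidden in function-local state.
gx = nn.grad([y], [x])[0]
gx.forward()
```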
2 changes: 2 additions & 0 deletions include/nbla/cuda/cudnn/function/gru.hpp
@@ -36,6 +36,7 @@ template <typename T> class GRUCudaCudnn : public GRU<T> {
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
}
virtual bool grad_depends_output_data(int i, int o) const { return true; }

protected:
int seq_len_;
@@ -92,6 +93,7 @@ template <typename T> class GRUCudaCudnn : public GRU<T> {
virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);
virtual bool grad_depends_input_data_impl(int i, int o) const { return true; }
};
}
#endif
2 changes: 2 additions & 0 deletions include/nbla/cuda/cudnn/function/lstm.hpp
@@ -36,6 +36,7 @@ template <typename T> class LSTMCudaCudnn : public LSTM<T> {
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
}
virtual bool grad_depends_output_data(int i, int o) const { return true; }

protected:
int seq_len_;
@@ -92,6 +93,7 @@ template <typename T> class LSTMCudaCudnn : public LSTM<T> {
virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);
virtual bool grad_depends_input_data_impl(int i, int o) const { return true; }
};
}
#endif
1 change: 1 addition & 0 deletions include/nbla/cuda/cudnn/function/relu.hpp
@@ -59,6 +59,7 @@ template <typename T> class ReLUCudaCudnn : public ReLUCuda<T> {
virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);
virtual bool grad_depends_input_data_impl(int i, int j) const { return true; }
};
}
#endif
2 changes: 2 additions & 0 deletions include/nbla/cuda/cudnn/function/rnn.hpp
@@ -94,6 +94,7 @@ template <typename T> class RNNCudaCudnn : public RNN<T> {
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
}
virtual bool grad_depends_output_data(int i, int o) const { return true; }

protected:
int seq_len_;
@@ -150,6 +151,7 @@ template <typename T> class RNNCudaCudnn : public RNN<T> {
virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);
virtual bool grad_depends_input_data_impl(int i, int o) const { return true; }
};
}
#endif
2 changes: 1 addition & 1 deletion include/nbla/cuda/function/add_scalar.hpp
@@ -22,6 +22,6 @@ namespace nbla {

/** @copydoc AddScalar
*/
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(AddScalar, double);
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(AddScalar, double, false);
}
#endif
2 changes: 1 addition & 1 deletion include/nbla/cuda/function/bc_add2.hpp
@@ -22,6 +22,6 @@ namespace nbla {

/** @copydoc BcAdd2
*/
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(BcAdd2);
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(BcAdd2, false);
}
#endif
3 changes: 2 additions & 1 deletion include/nbla/cuda/function/div2.hpp
@@ -22,6 +22,7 @@ namespace nbla {

/** @copydoc Div2
*/
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Div2);
// In-placing is obsoleted.
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Div2, true);
}
#endif
5 changes: 3 additions & 2 deletions include/nbla/cuda/function/dropout.hpp
@@ -27,8 +27,9 @@ template <typename T> class DropoutCuda : public Dropout<T> {
public:
typedef typename CudaType<T>::type Tc;

explicit DropoutCuda(const Context &ctx, double p, int seed = -1)
: Dropout<T>(ctx, T(p), seed) {
explicit DropoutCuda(const Context &ctx, double p, int seed = -1,
bool output_mask = false)
: Dropout<T>(ctx, T(p), seed, output_mask) {
cuda_set_device(std::stoi(ctx.device_id));
NBLA_CHECK(this->p_ > 0., error_code::value,
"p must be between 0.0 and 1.0");
3 changes: 2 additions & 1 deletion include/nbla/cuda/function/mul2.hpp
@@ -22,6 +22,7 @@ namespace nbla {

/** @copydoc Mul2
*/
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Mul2);
// In-placing is obsoleted.
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Mul2, true);
}
#endif
2 changes: 1 addition & 1 deletion include/nbla/cuda/function/mul_scalar.hpp
@@ -22,6 +22,6 @@ namespace nbla {

/** @copydoc MulScalar
*/
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(MulScalar, double);
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(MulScalar, double, false);
}
#endif
3 changes: 2 additions & 1 deletion include/nbla/cuda/function/pow2.hpp
@@ -22,6 +22,7 @@ namespace nbla {

/** @copydoc Pow2
*/
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Pow2);
// In-placing is obsoleted.
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Pow2, true);
}
#endif
3 changes: 2 additions & 1 deletion include/nbla/cuda/function/pow_scalar.hpp
@@ -22,6 +22,7 @@ namespace nbla {

/** @copydoc PowScalar
*/
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(PowScalar, double);
// In-placing is obsoleted.
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(PowScalar, double, true);
}
#endif
2 changes: 1 addition & 1 deletion include/nbla/cuda/function/sub2.hpp
@@ -22,6 +22,6 @@ namespace nbla {

/** @copydoc Sub2
*/
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Sub2);
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(Sub2, false);
}
#endif
4 changes: 2 additions & 2 deletions include/nbla/cuda/function/utils/base_transform_binary.hpp
@@ -81,11 +81,11 @@ protected: \
virtual bool grad_depends_input_data_impl(int i, int j) const; \
}

#define NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(NAME) \
#define NBLA_DECLARE_TRANSFORM_BINARY_CUDA_INPLACE(NAME, IGNORE_INPLACE) \
template <typename T> class NAME##Cuda : public TransformBinaryCuda<T> { \
NBLA_DECLARE_TRANSFORM_BINARY_CUDA_CLASS_COMMON(NAME) \
explicit NAME##Cuda(const Context &ctx, bool inplace) \
: TransformBinaryCuda<T>(ctx, inplace) {} \
: TransformBinaryCuda<T>(ctx, (IGNORE_INPLACE) ? false : inplace) {} \
virtual shared_ptr<Function> copy() const { \
return create_##NAME(this->ctx_, this->inplace_); \
} \
5 changes: 3 additions & 2 deletions include/nbla/cuda/function/utils/base_transform_unary.hpp
@@ -105,11 +105,12 @@ protected: \
virtual bool grad_depends_input_data_impl(int i, int j) const; \
}

#define NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(NAME, A0) \
#define NBLA_DECLARE_TRANSFORM_UNARY_CUDA_1_INPLACE(NAME, A0, IGNORE_INPLACE) \
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_CLASS_BEGIN_N(NAME, A0) { \
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_CLASS_COMMON(NAME); \
explicit NAME##Cuda(const Context &ctx, const A0 &a0, bool inplace) \
: TransformUnaryCuda<T, A0>(ctx, inplace, a0) {} \
: TransformUnaryCuda<T, A0>(ctx, (IGNORE_INPLACE) ? false : inplace, \
a0) {} \
virtual shared_ptr<Function> copy() const { \
return create_##NAME(this->ctx_, std::get<0>(this->args_), \
this->inplace_); \
4 changes: 2 additions & 2 deletions src/nbla/cuda/function/generic/div2.cu
@@ -21,7 +21,7 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_BINARY_CUDA(Div2, x0 / x1, dy / x1,
dy *(-(inplace ? y *x1 : x0) / (x1 * x1)),
// Inplacing is obsoleted.
NBLA_DEFINE_TRANSFORM_BINARY_CUDA(Div2, x0 / x1, dy / x1, dy *(-x0 / (x1 * x1)),
false, false, true, true);
}
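For reference, with y = x0 / x1 the gradients are

$$\frac{\partial y}{\partial x_0} = \frac{1}{x_1}, \qquad \frac{\partial y}{\partial x_1} = -\frac{x_0}{x_1^{2}}.$$

The removed branch reconstructed x0 as y * x1 when the function ran in place (the input buffer may already hold y). With in-placing obsoleted for Div2, the direct form with x0 is used and no reconstruction from the output is needed.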
12 changes: 8 additions & 4 deletions src/nbla/cuda/function/generic/dropout.cu
@@ -43,7 +43,11 @@ template <typename T>
void DropoutCuda<T>::setup_impl(const Variables &inputs,
const Variables &outputs) {
outputs[0]->reshape(inputs[0]->shape(), true);
this->mask_.reshape(inputs[0]->shape(), true);
if (this->output_mask_) {
outputs[1]->reshape(inputs[0]->shape(), true);
} else {
this->mask_.reshape(inputs[0]->shape(), true);
}
}

template <class T>
@@ -52,7 +56,7 @@ void DropoutCuda<T>::forward_impl(const Variables &inputs,
cuda_set_device(std::stoi(this->ctx_.device_id));
const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true);
Variable &mask = this->mask_;
Variable &mask = this->output_mask_ ? *outputs[1] : this->mask_;
float *m = mask.cast_data_and_get_pointer<float>(this->ctx_, true);
curandGenerator_t &gen =
this->seed_ == -1 ? SingletonManager::get<Cuda>()->curand_generator()
@@ -68,7 +72,7 @@ void DropoutCuda<T>::recompute_impl(const Variables &inputs,
cuda_set_device(std::stoi(this->ctx_.device_id));
const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true);
Variable &mask = this->mask_;
Variable &mask = this->output_mask_ ? *outputs[1] : this->mask_;
float *m = mask.cast_data_and_get_pointer<float>(this->ctx_, true);
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_dropout_forward, inputs[0]->size(),
this->scale_, this->p_, x, y, m);
@@ -85,7 +89,7 @@ void DropoutCuda<T>::backward_impl(const Variables &inputs,
cuda_set_device(std::stoi(this->ctx_.device_id));
Tc *dx = inputs[0]->cast_grad_and_get_pointer<Tc>(this->ctx_, !accum[0]);
const Tc *dy = outputs[0]->get_grad_pointer<Tc>(this->ctx_);
Variable &mask = this->mask_;
Variable &mask = this->output_mask_ ? *outputs[1] : this->mask_;
const float *m = mask.get_data_pointer<float>(this->ctx_);
if (accum[0]) {
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_dropout_backward<Tc, true>),
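Routing the mask through outputs[1] when output_mask is set, instead of the function-local mask_ variable, lets the graph retain the mask and reuse it when building the gradient graph, which is what nn.grad needs. A hedged usage sketch, assuming the companion core-side change exposes the flag in the Python API:

```python
# Sketch only: assumes F.dropout accepts output_mask (added on the nnabla core
# side together with this PR) and then returns the mask as a second output.
y, mask = F.dropout(x, p=0.5, output_mask=True)
```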
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/exp.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Exp, std::exp(x), dy *exp(x), false, true);
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Exp, std::exp(x), y *dy, true, false);
}
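The new backward expression reuses the output: since y = exp(x),

$$\frac{\partial L}{\partial x} = \dot{y}\, e^{x} = \dot{y}\, y,$$

so the gradient can be computed from the output data alone. The trailing flags, which appear to declare whether the gradient reads the output and the input data respectively, are flipped to match.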
13 changes: 7 additions & 6 deletions src/nbla/cuda/function/generic/leaky_relu.cu
@@ -36,16 +36,16 @@ __global__ void kernel_leaky_relu_forward(const int num, T *y, const T *x,
}

template <typename T, bool accum = true>
__global__ void kernel_leaky_relu_backward(const int num, T *dx, const T *x,
__global__ void kernel_leaky_relu_backward(const int num, T *dx, const T *sign,
const T *dy, float alpha) {
NBLA_CUDA_KERNEL_LOOP(idx, num) {
if (accum) {
if (x[idx] > 0)
if (sign[idx] > 0)
dx[idx] += dy[idx];
else
dx[idx] += alpha * dy[idx];
} else {
if (x[idx] > 0)
if (sign[idx] > 0)
dx[idx] = dy[idx];
else
dx[idx] = alpha * dy[idx];
@@ -74,17 +74,18 @@ void LeakyReLUCuda<T>::backward_impl(const Variables &inputs,
return;
}
cuda_set_device(std::stoi(this->ctx_.device_id));
const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
const Tc *sign = this->inplace_ ? outputs[0]->get_data_pointer<Tc>(this->ctx_)
: inputs[0]->get_data_pointer<Tc>(this->ctx_);
Tc *dx = inputs[0]->cast_grad_and_get_pointer<Tc>(
this->ctx_, !(this->inplace_ || accum[0]));
const Tc *dy = outputs[0]->get_grad_pointer<Tc>(this->ctx_);
size_t size = inputs[0]->size();
if (dx != dy && accum[0]) {
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_leaky_relu_backward<Tc, true>), size,
dx, x, dy, this->alpha_);
dx, sign, dy, this->alpha_);
} else {
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_leaky_relu_backward<Tc, false>),
size, dx, x, dy, this->alpha_);
size, dx, sign, dy, this->alpha_);
}
}
}
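Reading the branch sign from the output when the function runs in place matches the in-place data layout (the output has overwritten the input buffer) and is mathematically equivalent for alpha >= 0:

$$y = \begin{cases} x & x > 0 \\ \alpha x & x \le 0 \end{cases} \quad\Rightarrow\quad (y > 0) \iff (x > 0).$$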
4 changes: 2 additions & 2 deletions src/nbla/cuda/function/generic/mul2.cu
@@ -19,7 +19,7 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_BINARY_CUDA(Mul2, x0 *x1, dy *x1,
inplace ? dy *y / x1 : dy *x0, false, false,
// Inplacing is obsoleted.
NBLA_DEFINE_TRANSFORM_BINARY_CUDA(Mul2, x0 *x1, dy *x1, dy *x0, false, false,
true, true);
}
11 changes: 5 additions & 6 deletions src/nbla/cuda/function/generic/pow2.cu
@@ -21,10 +21,9 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_BINARY_CUDA(
Pow2, std::pow(x0, x1),
dy *x1 *std::pow(inplace ? std::pow(y, 1 / x1) : x0, x1 - (T)1),
dy *std::log(inplace ? std::pow(y, 1 / x1) : x0) *
std::pow(inplace ? std::pow(y, 1 / x1) : x0, x1),
false, false, true, true);
// Inplacing is obsoleted.
NBLA_DEFINE_TRANSFORM_BINARY_CUDA(Pow2, std::pow(x0, x1),
dy *x1 *std::pow(x0, x1 - (T)1),
dy *std::log(x0) * std::pow(x0, x1), false,
false, true, true);
}
5 changes: 2 additions & 3 deletions src/nbla/cuda/function/generic/pow_scalar.cu
@@ -19,10 +19,9 @@

namespace nbla {

// Inplacing is obsoleted.
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1(
PowScalar,
a0 == 0.5f ? std::sqrt(x) : a0 == -0.5f ? rsqrt(x) : std::pow(x, (T)a0),
dy *(T)a0 *std::pow((inplace ? std::pow(y, (T)1 / (T)a0) : x),
(T)a0 - (T)1),
false, true, double);
dy *(T)a0 *std::pow(x, (T)a0 - (T)1), false, true, double);
}
10 changes: 5 additions & 5 deletions src/nbla/cuda/function/generic/relu.cu
@@ -28,10 +28,10 @@ __global__ void kernel_relu_forward(const int num, T *y, const T *x) {
}

template <typename T, bool accum = true>
__global__ void kernel_relu_backward(const int num, T *dx, const T *x,
__global__ void kernel_relu_backward(const int num, T *dx, const T *y,
const T *dy) {
NBLA_CUDA_KERNEL_LOOP(idx, num) {
dx[idx] = (accum ? dx[idx] : (T)0) + (x[idx] > 0 ? dy[idx] : (T)0);
dx[idx] = (accum ? dx[idx] : (T)0) + (y[idx] > 0 ? dy[idx] : (T)0);
}
}

@@ -55,17 +55,17 @@ void ReLUCuda<T>::backward_impl(const Variables &inputs,
return;
}
cuda_set_device(std::stoi(this->ctx_.device_id));
const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
const Tc *y = outputs[0]->get_data_pointer<Tc>(this->ctx_);
Tc *dx = inputs[0]->cast_grad_and_get_pointer<Tc>(
this->ctx_, !(this->inplace_ || accum[0]));
const Tc *dy = outputs[0]->get_grad_pointer<Tc>(this->ctx_);
size_t size = inputs[0]->size();
if (dx != dy && accum[0]) {
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_relu_backward<Tc, true>), size, dx,
x, dy);
y, dy);
} else {
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_relu_backward<Tc, false>), size, dx,
x, dy);
y, dy);
}
}
}
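ReLU's backward kernel now branches on the output unconditionally, which is valid because

$$y = \max(x, 0) \quad\Rightarrow\quad (y > 0) \iff (x > 0),$$

so the kernel no longer needs the input data at all; the output is the buffer that remains valid both under in-place execution and when input buffers are released by clear_no_need_grad.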