Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements #39853
2. Adds an `approximate` string flag ('none' or 'tanh') to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enables Tanh Gelu in NvFuser

```
import math

import torch

def normcdf(x):
    # Standard normal CDF: 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + torch.erf(x * math.sqrt(0.5)))

def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```
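
A quick usage sketch of the new flag (illustrative; assumes the `approximate` keyword on `torch.nn.functional.gelu` introduced by this change):

```
import torch
import torch.nn.functional as F

x = torch.randn(16, dtype=torch.double)
exact = F.gelu(x)                       # default approximate='none' (erf path)
approx = F.gelu(x, approximate='tanh')  # new tanh approximation

# The approximation tracks the exact Gelu closely for typical activations.
print((exact - approx).abs().max())
```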

Companion XLA PR: pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9)
rdspring1 authored and pytorchmergebot committed Jan 28, 2022
1 parent e58d5b7 commit f499ab9
Showing 48 changed files with 752 additions and 202 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
-  KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
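The hunk above keeps `gelu` on the fp32 cast-policy list for CPU autocast, updating its schema for the new `int64_t` argument. A sketch of the observable behavior (illustrative; assumes a build containing this change):

```
import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.bfloat16)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    # gelu is on the fp32 list for CPU autocast, so the input is upcast
    # to float32 before the (now two-argument) kernel runs.
    y = F.gelu(x, approximate='tanh')

print(y.dtype)  # fp32-listed ops compute and return float32 under autocast
```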
20 changes: 10 additions & 10 deletions aten/src/ATen/native/Activation.cpp
@@ -164,12 +164,12 @@ TORCH_META_FUNC(softshrink_backward) (
build_borrowing_binary_op(maybe_get_output(), grad, self);
}

-TORCH_META_FUNC(gelu) (const Tensor & self) {
+TORCH_META_FUNC(gelu) (const Tensor & self, int64_t approximate) {
build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(gelu_backward) (
-    const Tensor& grad, const Tensor& self
+    const Tensor& grad, const Tensor& self, int64_t approximate
) {
build_borrowing_binary_op(maybe_get_output(), grad, self);
}
@@ -324,37 +324,37 @@ bool use_mkldnn(const Tensor& input) {
}

TORCH_IMPL_FUNC(gelu_out_cpu) (
-  const Tensor& self, const Tensor& result
+  const Tensor& self, int64_t approximate, const Tensor& result
) {
#if AT_MKLDNN_ENABLED()
-  if (use_mkldnn(self)) {
+  if (use_mkldnn(self) && (approximate == at::Gelu::None)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor y = itensor_from_tensor(result);
ideep::eltwise_forward::compute(
x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
} else {
-    GeluKernel(kCPU, *this);
+    GeluKernel(kCPU, *this, approximate);
}
#else
-  GeluKernel(kCPU, *this);
+  GeluKernel(kCPU, *this, approximate);
#endif
}

TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
-    const Tensor& grad, const Tensor& self, const Tensor& grad_input
+    const Tensor& grad, const Tensor& self, int64_t approximate, const Tensor& grad_input
) {
#if AT_MKLDNN_ENABLED()
-  if (use_mkldnn(self)) {
+  if (use_mkldnn(self) && (approximate == at::Gelu::None)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor grady = itensor_from_tensor(grad);
ideep::tensor gradx = itensor_from_tensor(grad_input);
ideep::eltwise_backward::compute(x, grady, gradx,
ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
} else {
-    GeluBackwardKernel(kCPU, *this);
+    GeluBackwardKernel(kCPU, *this, approximate);
}
#else
-  GeluBackwardKernel(kCPU, *this);
+  GeluBackwardKernel(kCPU, *this, approximate);
#endif
}

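Both branches above gate the oneDNN (mkldnn) fast path on `approximate == at::Gelu::None`, because `ideep::algorithm::eltwise_gelu_erf` only implements the exact erf form. A runnable pseudo-Python sketch of that dispatch (the stub names are illustrative, not real APIs):

```
GELU_NONE, GELU_TANH = 0, 1  # mirrors at::Gelu::{None, Tanh}

# Stubs standing in for the real mkldnn checks and kernels.
def mkldnn_enabled():
    return True

def use_mkldnn(tensor):
    return True

def mkldnn_gelu_erf(tensor):
    return 'ideep eltwise_gelu_erf path'

def generic_gelu_kernel(tensor, approximate):
    return f'GeluKernel(kCPU, approximate={approximate})'

def gelu_out_cpu(tensor, approximate):
    # oneDNN only covers the exact erf-based Gelu; the tanh approximation
    # always falls through to the generic vectorized kernel.
    if mkldnn_enabled() and use_mkldnn(tensor) and approximate == GELU_NONE:
        return mkldnn_gelu_erf(tensor)
    return generic_gelu_kernel(tensor, approximate)

print(gelu_out_cpu(None, GELU_TANH))  # generic kernel, even with mkldnn on
```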
19 changes: 17 additions & 2 deletions aten/src/ATen/native/Activation.h
@@ -12,6 +12,19 @@ struct TensorIteratorBase;
class TensorBase;
}

+namespace at {
+namespace Gelu {
+
+// Keep this in sync with the Gelu class in torch/nn/_gelu.py.
+// These constants control the approximation behavior of gelu functions.
+enum Gelu {
+  None, // Baseline Gelu
+  Tanh, // Tanh Gelu approximation
+  END
+};
+} // namespace Gelu
+} // namespace at
+
namespace at { namespace native {

using structured_activation_fn = void (*)(TensorIteratorBase&);
@@ -35,6 +48,8 @@ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
+using gelu_fn = void (*)(TensorIteratorBase&, int64_t);
+using gelu_backward_fn = void (*)(TensorIteratorBase&, int64_t);

DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
@@ -43,8 +58,8 @@ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
-DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
-DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
+DECLARE_DISPATCH(gelu_fn, GeluKernel);
+DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
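The kernels receive the approximation mode as an `int64_t` matching this enum, while the Python-facing API takes the strings `'none'` and `'tanh'`. A minimal sketch of the mapping that has to stay in sync (the names here are hypothetical, not the commit's actual helper):

```
# Hypothetical mirror of the C++ at::Gelu enum; values must stay in sync.
GELU_NONE = 0  # at::Gelu::None (exact, erf-based Gelu)
GELU_TANH = 1  # at::Gelu::Tanh (tanh approximation)

def gelu_approximate_to_enum(approximate: str) -> int:
    if approximate == 'none':
        return GELU_NONE
    if approximate == 'tanh':
        return GELU_TANH
    raise ValueError(f"approximate must be 'none' or 'tanh', got {approximate!r}")
```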
173 changes: 127 additions & 46 deletions aten/src/ATen/native/cpu/Activation.cpp
@@ -166,7 +166,7 @@ void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scal
// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
-void GeluKernelImpl(TensorIteratorBase& it) {
+void GeluKernelImpl(TensorIteratorBase& it, int64_t approximate) {
auto grain_size = at::internal::GRAIN_SIZE;
// Numbers based on benchmarking.
// Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
@@ -187,53 +187,134 @@ void GeluKernelImpl(TensorIteratorBase& it) {
if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
grain_size = it.numel() / at::get_num_threads();
}
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
-        using Vec = vec::Vectorized<scalar_t>;
-        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
-        const Vec kOneVec(scalar_t(1));
-        const Vec kPointFiveVec(scalar_t(0.5));
-        cpu_kernel_vec(
-            it,
-            [](scalar_t x) {
-              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
-              return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
-            },
-            [&](Vec x_vec) {
-              return x_vec * kPointFiveVec *
-                  (kOneVec + (x_vec * kAlphaVec).erf());
-            },
-            grain_size);
-      });
+  if (approximate == at::Gelu::Tanh) {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
+          const Vec kKappaVec(scalar_t(0.044715));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t x) {
+                const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+                const scalar_t kKappa = 0.044715;
+                auto x_cube = x * x * x;
+                auto inner = kBeta * (x + kKappa * x_cube);
+                return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
+              },
+              [&](Vec x_vec) {
+                auto x_cube = x_vec * x_vec * x_vec;
+                auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
+                return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
+              },
+              grain_size);
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kAlphaVec(scalar_t(M_SQRT1_2));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t x) {
+                const scalar_t kAlpha = scalar_t(M_SQRT1_2);
+                return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+              },
+              [&](Vec x_vec) {
+                return x_vec * kPointFiveVec *
+                    (kOneVec + (x_vec * kAlphaVec).erf());
+              },
+              grain_size);
+        });
+  }
}

-void GeluBackwardKernelImpl(TensorIteratorBase& it) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
-        using Vec = vec::Vectorized<scalar_t>;
-        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
-        const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
-        const Vec kOneVec(scalar_t(1));
-        const Vec kPointFiveVec(scalar_t(0.5));
-        const Vec kMinusPointFiveVec(scalar_t(-0.5));
-        cpu_kernel_vec(
-            it,
-            [](scalar_t dy, scalar_t x) {
-              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
-              const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
-              const scalar_t cdf =
-                  scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
-              const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
-              return dy * (cdf + x * pdf);
-            },
-            [&](Vec dy_vec, Vec x_vec) {
-              const Vec cdf_vec =
-                  kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
-              const Vec pdf_vec =
-                  kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
-              return dy_vec * (cdf_vec + x_vec * pdf_vec);
-            });
-      });
+void GeluBackwardKernelImpl(TensorIteratorBase& it, int64_t approximate) {
+  if (approximate == at::Gelu::Tanh) {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
+          const Vec kKappaVec(scalar_t(0.044715));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kThreeVec(scalar_t(3));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t dy, scalar_t x) {
+                const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+                const scalar_t kKappa = 0.044715;
+                auto x_sq = x * x;
+                auto x_cube = x_sq * x;
+                auto inner = kBeta * (x + kKappa * x_cube);
+                auto tanh_inner = std::tanh(inner);
+
+                auto left = scalar_t(0.5) * x;
+                auto right = scalar_t(1) + tanh_inner;
+
+                auto left_derivative = scalar_t(0.5) * right;
+
+                auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
+                auto inner_derivative =
+                    kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
+                auto right_derivative = left * tanh_derivative * inner_derivative;
+
+                return dy * (left_derivative + right_derivative);
+              },
+              [&](Vec dy_vec, Vec x_vec) {
+                auto x_sq = x_vec * x_vec;
+                auto x_cube = x_vec * x_vec * x_vec;
+                auto inner_vec =
+                    kBetaVec * (x_vec + kKappaVec * x_cube);
+                auto tanh_inner_vec = inner_vec.tanh();
+
+                auto left_vec = kPointFiveVec * x_vec;
+                auto right_vec = kOneVec + tanh_inner_vec;
+
+                auto left_derivative_vec = kPointFiveVec * right_vec;
+
+                auto tanh_derivative_vec =
+                    kOneVec - tanh_inner_vec * tanh_inner_vec;
+                auto inner_derivative_vec =
+                    kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
+                auto right_derivative_vec =
+                    left_vec * tanh_derivative_vec * inner_derivative_vec;
+
+                return dy_vec * (left_derivative_vec + right_derivative_vec);
+              });
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kAlphaVec(scalar_t(M_SQRT1_2));
+          const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          const Vec kMinusPointFiveVec(scalar_t(-0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t dy, scalar_t x) {
+                const scalar_t kAlpha = scalar_t(M_SQRT1_2);
+                const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
+                const scalar_t cdf =
+                    scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+                const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
+                return dy * (cdf + x * pdf);
+              },
+              [&](Vec dy_vec, Vec x_vec) {
+                const Vec cdf_vec =
+                    kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
+                const Vec pdf_vec =
+                    kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
+                return dy_vec * (cdf_vec + x_vec * pdf_vec);
+              });
+        });
+  }
}

void hardsigmoid_kernel(TensorIteratorBase& iter) {
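As a sanity check on the tanh branch above, the sketch below re-derives the closed-form gradient in Python and compares it against autograd; `kBeta` and `kKappa` mirror the kernel constants (an illustrative check, not code from the commit):

```
import math

import torch

kBeta = math.sqrt(2.0 / math.pi)  # equals M_SQRT2 * M_2_SQRTPI * 0.5
kKappa = 0.044715

def tanh_gelu(x):
    inner = kBeta * (x + kKappa * x ** 3)
    return 0.5 * x * (1.0 + torch.tanh(inner))

def tanh_gelu_grad(x):
    # Closed-form derivative mirroring GeluBackwardKernelImpl's tanh branch
    # (the incoming gradient dy is folded in by the caller).
    inner = kBeta * (x + kKappa * x ** 3)
    tanh_inner = torch.tanh(inner)
    left = 0.5 * x
    right = 1.0 + tanh_inner
    left_derivative = 0.5 * right
    tanh_derivative = 1.0 - tanh_inner * tanh_inner
    inner_derivative = kBeta * (1.0 + 3.0 * kKappa * x * x)
    right_derivative = left * tanh_derivative * inner_derivative
    return left_derivative + right_derivative

# The kernel constant really is sqrt(2/pi).
assert math.isclose(kBeta, math.sqrt(2.0) * (2.0 / math.sqrt(math.pi)) * 0.5)

x = torch.linspace(-4.0, 4.0, 101, dtype=torch.double, requires_grad=True)
(dy,) = torch.autograd.grad(tanh_gelu(x).sum(), x)
assert torch.allclose(dy, tanh_gelu_grad(x.detach()))
```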
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/Activation.cpp
@@ -153,15 +153,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
}

TORCH_IMPL_FUNC(gelu_out_cuda) (
-  const Tensor& /*self*/, const Tensor& /*result*/
-) {
-  GeluCUDAKernelImpl(*this);
+  const Tensor& /*self*/, int64_t approximate, const Tensor& /*result*/
+) {
+  GeluCUDAKernelImpl(*this, approximate);
}

TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
-  const Tensor& /*grad*/, const Tensor& /*self*/, const Tensor& /*grad_input*/
-) {
-  GeluBackwardCUDAKernelImpl(*this);
+  const Tensor& /*grad*/, const Tensor& /*self*/, int64_t approximate, const Tensor& /*grad_input*/
+) {
+  GeluBackwardCUDAKernelImpl(*this, approximate);
}

}} // namespace at::native

1 comment on commit f499ab9

@vadimkantorov (Contributor)


Just wondering: does it also support `None` along with `'none'`?
