Implement Tanh Gelu Approximation #61439

Closed

rdspring1 wants to merge 64 commits

Commits (64)
2b37e03
Enable gelu_double_backward
rdspring1 Jul 6, 2021
42c08f8
Add approximation flag to Gelu
rdspring1 Jul 7, 2021
63dde63
Fix MKLDNN::Gelu signature
rdspring1 Jul 9, 2021
e6e8f53
Update test_nn with tanh gelu - functional
rdspring1 Jul 10, 2021
7d81b9f
Fix sample_inputs_gelu
rdspring1 Jul 10, 2021
c757aa9
Fix test_cpp_api_parity
rdspring1 Jul 10, 2021
4a4eab3
Update torch overrides
rdspring1 Jul 12, 2021
7af7a5c
Enable Gelu for Tensor Expressions
rdspring1 Jul 12, 2021
11dba85
Add Gelu to UNTRACEABLE_FUNCTIONALS - test/test_fx
rdspring1 Jul 12, 2021
57a292c
Update torch/onnx/symbolic_opset
rdspring1 Jul 12, 2021
7d2cd5a
Merge remote-tracking branch 'upstream/master' into fast_gelu
rdspring1 Jul 13, 2021
a2336fa
Add Gelu to Tensor Expressions skip list
rdspring1 Jul 14, 2021
d47613c
Fix flake8 and mypy errors
rdspring1 Jul 14, 2021
987fef8
Update onnx tests
rdspring1 Jul 14, 2021
755e671
Update NvFuser Gelu handling
rdspring1 Jul 14, 2021
8d2a2d1
Add gelu + gelu_backward to backward_compatibility allow list
rdspring1 Jul 14, 2021
5c0ac6f
Merge branch 'master' of https://github.com/pytorch/pytorch into fast…
rdspring1 Jul 14, 2021
69ae2dc
Fix mypy and clang-format
rdspring1 Jul 15, 2021
c33bd4f
Merge branch 'master' of https://github.com/pytorch/pytorch into fast…
rdspring1 Jul 16, 2021
356a239
Add approximate flag to test_autocast_nn_fp32
rdspring1 Sep 30, 2021
d3bdfeb
Default approximate parameter for gelu
rdspring1 Sep 30, 2021
c319bf0
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Oct 5, 2021
846b923
Replace pow(x, 3) with x*x*x
rdspring1 Oct 5, 2021
b55911d
Fix gelu in symbolic-script
rdspring1 Oct 5, 2021
b5b6805
Do not use mkldnn for cpu tanh gelu
rdspring1 Oct 6, 2021
c562705
Update onnx tests
rdspring1 Oct 6, 2021
da253a5
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Oct 6, 2021
9893d11
Do not use mkldnn for tanh gelu backward
rdspring1 Oct 7, 2021
477cce5
Implement Tanh Gelu in Tensor Expressions
rdspring1 Oct 7, 2021
5e46a1d
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Oct 7, 2021
6306700
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Oct 26, 2021
34de839
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Nov 6, 2021
5901ded
Implement "approximate" string argument for tanh gelu
rdspring1 Nov 7, 2021
668e828
Update TensorExpr and NvFuser Gelu implementations
rdspring1 Nov 10, 2021
4886fa1
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Nov 20, 2021
a517263
Onnx Fix
rdspring1 Nov 20, 2021
dd7a09e
Replace gelu with relu in TestAutodiffSubgraphSlicing
rdspring1 Nov 20, 2021
a5ae9ca
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Dec 23, 2021
060ae67
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Jan 12, 2022
2e7c467
Add tanh approximation to quantized gelu
rdspring1 Jan 12, 2022
c89faed
Replace normcdf with erf in cuda implementation
rdspring1 Jan 12, 2022
06737b7
Fix "bad dtype in CompareSelect" with tensorexpr
rdspring1 Jan 12, 2022
10110fb
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Jan 28, 2022
efb260c
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Jan 31, 2022
69782cf
Mark 'approximate' argument keyword-only
rdspring1 Jan 31, 2022
9751375
Revert keyword-only requirement
rdspring1 Feb 3, 2022
cc01e15
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Feb 3, 2022
69804dc
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Feb 4, 2022
56aeea5
Initial string keyword argument implementation
rdspring1 Feb 6, 2022
2155f91
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Feb 6, 2022
200675e
Fix gelu constructor
rdspring1 Feb 6, 2022
0fa24f1
Disable keyword requirement for gradcheck
rdspring1 Feb 7, 2022
3623af2
Remove test_gelu from test_nn.py
rdspring1 Feb 8, 2022
c2cd1d3
Add string argument support to nvfuser
rdspring1 Feb 8, 2022
4b51d26
Set Gelu to BUILT_IN_FUNC for torch fx
rdspring1 Feb 8, 2022
51fd8d5
Address nvfuser comments
rdspring1 Feb 8, 2022
5282914
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Feb 9, 2022
5f88e0f
Add string argument support to tensorexpr
rdspring1 Feb 9, 2022
3b4d6c7
Add keyword to onnx tests
rdspring1 Feb 9, 2022
b1c8447
Merge branch 'master' of github.com:rdspring1/pytorch into fast_gelu
rdspring1 Feb 9, 2022
3204232
Use opmath_t in Activation.cu
rdspring1 Feb 9, 2022
0b5f234
Added numpy reference to gelu OpInfo test
rdspring1 Feb 9, 2022
467434a
Forward AD formula for gelu activation backwards
rdspring1 Feb 9, 2022
fbd5e62
Add stats from SCIPY for reference_gelu test
rdspring1 Feb 9, 2022
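
For context before the file-by-file diff: the user-facing change is a new `approximate` string argument on the gelu operator, accepting "none" (the existing erf-based path, kept as the default) and "tanh" for the approximation. Below is a minimal ATen-level usage sketch, assuming the generated `at::gelu` overload mirrors the native signature changed in this PR; it is illustrative only and not part of the diff.

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({8, 16});

  // Erf-based GELU: the pre-existing behavior; "none" is the default.
  at::Tensor y_exact = at::gelu(x, "none");

  // Tanh approximation added by this PR.
  at::Tensor y_tanh = at::gelu(x, "tanh");

  return 0;
}
```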
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
22 changes: 12 additions & 10 deletions aten/src/ATen/native/Activation.cpp
@@ -164,12 +164,12 @@ TORCH_META_FUNC(softshrink_backward) (
build_borrowing_binary_op(maybe_get_output(), grad, self);
}

TORCH_META_FUNC(gelu) (const Tensor & self) {
TORCH_META_FUNC(gelu) (const Tensor & self, c10::string_view approximate) {
build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(gelu_backward) (
const Tensor& grad, const Tensor& self
const Tensor& grad, const Tensor& self, c10::string_view approximate
) {
build_borrowing_binary_op(maybe_get_output(), grad, self);
}
@@ -324,37 +324,39 @@ bool use_mkldnn(const Tensor& input) {
}

TORCH_IMPL_FUNC(gelu_out_cpu) (
const Tensor& self, const Tensor& result
const Tensor& self, c10::string_view approximate, const Tensor& result
) {
auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
if (use_mkldnn(self)) {
if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor y = itensor_from_tensor(result);
ideep::eltwise_forward::compute(
x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
} else {
GeluKernel(kCPU, *this);
GeluKernel(kCPU, *this, approximate_type);
}
#else
GeluKernel(kCPU, *this);
GeluKernel(kCPU, *this, approximate_type);
#endif
}

TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
const Tensor& grad, const Tensor& self, const Tensor& grad_input
const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
) {
auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
if (use_mkldnn(self)) {
if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor grady = itensor_from_tensor(grad);
ideep::tensor gradx = itensor_from_tensor(grad_input);
ideep::eltwise_backward::compute(x, grady, gradx,
ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
} else {
GeluBackwardKernel(kCPU, *this);
GeluBackwardKernel(kCPU, *this, approximate_type);
}
#else
GeluBackwardKernel(kCPU, *this);
GeluBackwardKernel(kCPU, *this, approximate_type);
#endif
}

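A note on the hunk above: `ideep::algorithm::eltwise_gelu_erf` covers only the erf-based definition, which is why the MKL-DNN fast path is now gated on `approximate_type == GeluType::None` and the tanh mode always falls through to `GeluKernel`/`GeluBackwardKernel`. For reference, the two modes correspond to the closed forms below; the constants match `kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi)` and `kKappa = 0.044715` used in the CPU kernel later in this diff.

```latex
\mathrm{GELU}_{\text{none}}(x) = x\,\Phi(x)
  = \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right)

\mathrm{GELU}_{\text{tanh}}(x)
  = \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
```
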
23 changes: 21 additions & 2 deletions aten/src/ATen/native/Activation.h
@@ -14,6 +14,23 @@ class TensorBase;

namespace at { namespace native {

// These constants control the approximation behavior of the gelu function.
enum GeluType {
None, // Baseline Gelu
Tanh, // Tanh Gelu Approximation
END
};

static GeluType get_gelutype_enum(const c10::string_view approximate) {
if (approximate == "none") {
return GeluType::None;
} else if (approximate == "tanh") {
return GeluType::Tanh;
} else {
TORCH_CHECK(false, "approximate argument must be either none or tanh.");
}
}

using structured_activation_fn = void (*)(TensorIteratorBase&);
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);

@@ -35,6 +52,8 @@ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);

DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
@@ -43,8 +62,8 @@ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(gelu_fn, GeluKernel);
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
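To illustrate the string-to-enum mapping added in Activation.h, here is a standalone sketch (not part of the diff; it assumes the internal `ATen/native/Activation.h` header is includable from the calling translation unit):

```cpp
#include <ATen/native/Activation.h>
#include <c10/util/Exception.h>
#include <iostream>

int main() {
  using at::native::GeluType;
  using at::native::get_gelutype_enum;

  // "none" and "tanh" are the only accepted spellings.
  std::cout << (get_gelutype_enum("none") == GeluType::None) << "\n"; // 1
  std::cout << (get_gelutype_enum("tanh") == GeluType::Tanh) << "\n"; // 1

  // Anything else trips the TORCH_CHECK and surfaces as a c10::Error.
  try {
    get_gelutype_enum("fast");
  } catch (const c10::Error&) {
    std::cout << "rejected unknown approximate value\n";
  }
  return 0;
}
```
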
173 changes: 127 additions & 46 deletions aten/src/ATen/native/cpu/Activation.cpp
@@ -166,7 +166,7 @@ void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scal
// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
void GeluKernelImpl(TensorIteratorBase& it) {
void GeluKernelImpl(TensorIteratorBase& it, GeluType approximate) {
auto grain_size = at::internal::GRAIN_SIZE;
// Numbers based on benchmarking.
// Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
@@ -187,53 +187,134 @@ void GeluKernelImpl(TensorIteratorBase& it) {
if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
grain_size = it.numel() / at::get_num_threads();
}
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
},
grain_size);
});
if (approximate == GeluType::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
const Vec kKappaVec(scalar_t(0.044715));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const scalar_t kKappa = 0.044715;
auto x_cube = x * x * x;
auto inner = kBeta * (x + kKappa * x_cube);
return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
},
[&](Vec x_vec) {
auto x_cube = x_vec * x_vec * x_vec;
auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
},
grain_size);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
},
grain_size);
});
}
}

void GeluBackwardKernelImpl(TensorIteratorBase& it) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
const Vec kMinusPointFiveVec(scalar_t(-0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
void GeluBackwardKernelImpl(TensorIteratorBase& it, GeluType approximate) {
if (approximate == GeluType::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
const Vec kKappaVec(scalar_t(0.044715));
const Vec kOneVec(scalar_t(1));
const Vec kThreeVec(scalar_t(3));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const scalar_t kKappa = 0.044715;
auto x_sq = x * x;
auto x_cube = x_sq * x;
auto inner = kBeta * (x + kKappa * x_cube);
auto tanh_inner = std::tanh(inner);

auto left = scalar_t(0.5) * x;
auto right = scalar_t(1) + tanh_inner;

auto left_derivative = scalar_t(0.5) * right;

auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
auto inner_derivative =
kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
auto right_derivative = left * tanh_derivative * inner_derivative;

return dy * (left_derivative + right_derivative);
},
[&](Vec dy_vec, Vec x_vec) {
auto x_sq = x_vec * x_vec;
auto x_cube = x_vec * x_vec * x_vec;
auto inner_vec =
kBetaVec * (x_vec + kKappaVec * x_cube);
auto tanh_inner_vec = inner_vec.tanh();

auto left_vec = kPointFiveVec * x_vec;
auto right_vec = kOneVec + tanh_inner_vec;

auto left_derivative_vec = kPointFiveVec * right_vec;

auto tanh_derivative_vec =
kOneVec - tanh_inner_vec * tanh_inner_vec;
auto inner_derivative_vec =
kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
auto right_derivative_vec =
left_vec * tanh_derivative_vec * inner_derivative_vec;

return dy_vec * (left_derivative_vec + right_derivative_vec);
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
const Vec kMinusPointFiveVec(scalar_t(-0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
}
}

void hardsigmoid_kernel(TensorIteratorBase& iter) {
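The tanh branch of `GeluBackwardKernelImpl` above is a direct product-rule expansion. Writing u = beta * (x + kappa * x^3) with beta = sqrt(2/pi) and kappa = 0.044715, the `left_derivative` / `right_derivative` split in the scalar lambda corresponds to:

```latex
\frac{d}{dx}\,\mathrm{GELU}_{\text{tanh}}(x)
  = \underbrace{\tfrac{1}{2}\bigl(1 + \tanh u\bigr)}_{\texttt{left\_derivative}}
  + \underbrace{\tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,\beta\,\bigl(1 + 3\kappa x^{2}\bigr)}_{\texttt{right\_derivative}},
  \qquad u = \beta\,(x + \kappa x^{3})
```

The kernel then returns `dy` times this sum, mirroring the `dy * (cdf + x * pdf)` structure of the erf path.
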
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/Activation.cpp
@@ -156,15 +156,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
}

TORCH_IMPL_FUNC(gelu_out_cuda) (
const Tensor& /*self*/, const Tensor& /*result*/
) {
GeluCUDAKernelImpl(*this);
const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/
) {
GeluCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
const Tensor& /*grad*/, const Tensor& /*self*/, const Tensor& /*grad_input*/
) {
GeluBackwardCUDAKernelImpl(*this);
const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/
) {
GeluBackwardCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

}} // namespace at::native
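
As an end-to-end sanity check, here is a small sketch (illustrative only, not part of the PR's test suite; it assumes the `at::gelu` overload with the string argument introduced by this PR) comparing the tanh kernel against the closed-form expression:

```cpp
#include <ATen/ATen.h>
#include <cmath>

int main() {
  at::Tensor x = at::randn({1024}, at::kDouble);

  // Kernel output for the tanh approximation added in this PR.
  at::Tensor y_kernel = at::gelu(x, "tanh");

  // Closed-form reference: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
  const double kBeta = std::sqrt(2.0 / M_PI);
  at::Tensor inner = kBeta * (x + 0.044715 * x.pow(3));
  at::Tensor y_ref = 0.5 * x * (1.0 + at::tanh(inner));

  TORCH_CHECK(at::allclose(y_kernel, y_ref, /*rtol=*/1e-6, /*atol=*/1e-8),
              "tanh GELU kernel deviates from the reference formula");
  return 0;
}
```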