diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index fbd9b6a19db..8c6c999d20c 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -4857,6 +4857,19 @@ TEST_F(AtenXlaTensorTest, TestCeluInPlace) {
   ExpectCounterChanged("xla::elu_", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestGelu) {
+  torch::Tensor input =
+      torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor output = torch::gelu(input);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::gelu(xla_input);
+    AllClose(output, xla_output);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::gelu", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestAddMatMul) {
   int in_channels = 32;
   int out_channels = 320;
@@ -7809,6 +7822,19 @@ TEST_F(AtenXlaTensorTest, TestEluBackward) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestGeluBackward) {
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::gelu(inputs[0]);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::rand({2, 3},
+                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
+        device, testfn);
+  });
+  ExpectCounterChanged("xla::gelu_backward", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestLeakyReluBackward) {
   double negative_slope = 0.01;
   auto testfn = [=](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 67c64be826e..9a6ec061206 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1327,6 +1327,18 @@ at::Tensor& AtenXlaType::ge_(at::Tensor& self, const at::Tensor& other) {
   return self;
 }
 
+at::Tensor AtenXlaType::gelu(const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::gelu(bridge::GetXlaTensor(self)));
+}
+
+at::Tensor AtenXlaType::gelu_backward(const at::Tensor& grad,
+                                      const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::gelu_backward(
+      bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self)));
+}
+
 at::Tensor AtenXlaType::gt(const at::Tensor& self, at::Scalar other) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 1d5803de871..a13e804a23f 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -410,6 +410,11 @@ class AtenXlaType {
 
   static at::Tensor& ge_(at::Tensor& self, const at::Tensor& other);
 
+  static at::Tensor gelu(const at::Tensor& self);
+
+  static at::Tensor gelu_backward(const at::Tensor& grad,
+                                  const at::Tensor& self);
+
   static at::Tensor gt(const at::Tensor& self, at::Scalar other);
 
   static at::Tensor gt(const at::Tensor& self, const at::Tensor& other);
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 90ff539f3b8..6f6485d6d13 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -488,6 +488,22 @@ NodePtr EluBackward(const Value& grad_output, const Value& output,
                positive_output_branch, negative_output_branch);
 }
 
+NodePtr Gelu(const Value& input) {
+  // input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+  const xla::Shape& shape = input.shape();
+  return input * ScalarOp(0.5, shape) *
+         (Erf(input * ScalarOp(M_SQRT1_2, shape)) + ScalarOp(1.0, shape));
+}
+
+NodePtr GeluBackward(const Value& grad, const Value& input) {
+  const float kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
+  const xla::Shape& shape = input.shape();
+  NodePtr scratch = Erf(input * ScalarOp(M_SQRT1_2, shape));
+  NodePtr dinput = Exp(input * input * ScalarOp(-0.5, shape));
+  return grad * (ScalarOp(0.5, shape) * (ScalarOp(1.0, shape) + scratch) +
+                 input * dinput * ScalarOp(kAlpha, shape));
+}
+
 NodePtr Lshift(const Value& input, at::Scalar other) {
   return input * ScalarOp(pow(2, other.to<double>()), input.shape());
 }
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 2fa52cd27c3..23ae8afcb2f 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -178,6 +178,10 @@ NodePtr Elu(const Value& input, at::Scalar alpha, at::Scalar scale,
 NodePtr EluBackward(const Value& grad_output, const Value& output,
                     at::Scalar alpha, at::Scalar scale, at::Scalar input_scale);
 
+NodePtr Gelu(const Value& input);
+
+NodePtr GeluBackward(const Value& grad, const Value& input);
+
 NodePtr Lshift(const Value& input, at::Scalar other);
 
 NodePtr Lshift(const Value& input, const Value& other);
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index f814e9b8b69..8639b696b75 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -455,6 +455,9 @@ class XLATensor {
   static XLATensor ge(const XLATensor& input, const XLATensor& other);
   static void ge_(XLATensor& input, const XLATensor& other);
 
+  static XLATensor gelu(const XLATensor& input);
+  static XLATensor gelu_backward(const XLATensor& grad, const XLATensor& input);
+
   static XLATensor gt(const XLATensor& input, at::Scalar other);
   static void gt_(XLATensor& input, at::Scalar other);
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 6b5d3fc56cb..6ef860ec7f8 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1059,6 +1059,16 @@ void XLATensor::ge_(XLATensor& input, const XLATensor& other) {
   input.SetIrValue(ir::MakeNode<ir::ops::Cast>(cmp_result, input.dtype()));
 }
 
+XLATensor XLATensor::gelu(const XLATensor& input) {
+  return input.CreateFrom(ir::ops::Gelu(input.GetIrValue()));
+}
+
+XLATensor XLATensor::gelu_backward(const XLATensor& grad,
+                                   const XLATensor& input) {
+  return input.CreateFrom(
+      ir::ops::GeluBackward(grad.GetIrValue(), input.GetIrValue()));
+}
+
 XLATensor XLATensor::gt(const XLATensor& input, at::Scalar other) {
   return DispatchComparisonOp(at::aten::gt, input, other);
 }
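For context on the GeluBackward lowering above: with gelu(x) = x * Phi(x), where Phi is the standard normal CDF, the derivative is Phi(x) + x * phi(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi), which is what the IR node builds before multiplying by grad; kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5 is exactly 1 / sqrt(2 * pi). Below is a minimal standalone C++ sketch, not part of the patch (the GeluRef/GeluGradRef helper names are illustrative only), that checks this analytic gradient against a central finite difference of the erf-based forward:

```cpp
// Standalone sanity check for the GeluBackward formula; plain C++,
// no torch/XLA dependency. M_SQRT1_2 and M_2_SQRTPI come from <cmath>
// (POSIX math constants), the same ones used in the IR lowering.
#include <cmath>
#include <cstdio>
#include <initializer_list>

// Forward, matching the lowering: x * 0.5 * (1 + erf(x / sqrt(2))).
double GeluRef(double x) { return x * 0.5 * (1.0 + std::erf(x * M_SQRT1_2)); }

// Analytic derivative used by GeluBackward:
// 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi).
double GeluGradRef(double x) {
  const double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;  // == 1 / sqrt(2 * pi)
  return 0.5 * (1.0 + std::erf(x * M_SQRT1_2)) +
         x * std::exp(-0.5 * x * x) * kAlpha;
}

int main() {
  const double eps = 1e-6;
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    // The central finite difference of the forward should match the
    // analytic gradient at each probe point.
    double numeric = (GeluRef(x + eps) - GeluRef(x - eps)) / (2.0 * eps);
    std::printf("x=%+.1f  analytic=%.6f  numeric=%.6f\n", x, GeluGradRef(x),
                numeric);
  }
  return 0;
}
```

At x = 0 both columns should print 0.5, and the analytic and numeric values should agree closely at the other probe points.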