Skip to content

Commit

Permalink
Implement "approximate" string argument for tanh gelu
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 11, 2021
1 parent 4cb3440 commit e716d65
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 13 deletions.
4 changes: 3 additions & 1 deletion test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6264,7 +6264,9 @@ TEST_F(AtenXlaTensorTest, TestCeluInPlace) {
TEST_F(AtenXlaTensorTest, TestGelu) {
torch::Tensor input =
torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
for (bool approximate : {false, true}) {
const int64_t kNone = 0;
const int64_t kTanh = 1;
for (auto approximate : {kNone, kTanh}) {
torch::Tensor output = torch::gelu(input, approximate);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1506,15 +1506,16 @@ at::Tensor XLANativeFunctions::ge(const at::Tensor& self,
XLATensor::ge(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor XLANativeFunctions::gelu(const at::Tensor& self, bool approximate) {
// XLA dispatch entry point for aten::gelu.
// `approximate` selects the GELU formulation, encoded as an integer:
// 0 = "none" (exact, erf-based) and 1 = "tanh" approximation — matching the
// kNone/kTanh constants used in test/cpp/test_aten_xla_tensor.cpp and the
// kTanh check in ops.cpp. The tensor is unwrapped to an XLATensor, lowered,
// and re-wrapped for the ATen caller.
at::Tensor XLANativeFunctions::gelu(const at::Tensor& self,
int64_t approximate) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::gelu(bridge::GetXlaTensor(self), approximate));
}

at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
const at::Tensor& self,
bool approximate) {
int64_t approximate) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::gelu_backward(
bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
Expand Down
12 changes: 8 additions & 4 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,10 +619,11 @@ NodePtr EluBackward(const Value& grad_output, const Value& output,
positive_output_branch, negative_output_branch);
}

NodePtr Gelu(const Value& input, bool approximate) {
NodePtr Gelu(const Value& input, xla::int64 approximate) {
ScopePusher ir_scope("aten::gelu");
const xla::Shape& shape = input.shape();
if (approximate) {
const auto kTanh = 1;
if (approximate == kTanh) {
// inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(input, 3))
// input * 0.5 * (1.0 + torch.tanh(inner))
const float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
Expand All @@ -640,10 +641,13 @@ NodePtr Gelu(const Value& input, bool approximate) {
}
}

NodePtr GeluBackward(const Value& grad, const Value& input, bool approximate) {
NodePtr GeluBackward(const Value& grad, const Value& input,
xla::int64 approximate) {
ScopePusher ir_scope("aten::gelu_backward");
const xla::Shape& shape = input.shape();
if (approximate) {
const int64_t kNone = 0;
const int64_t kTanh = 1;
if (approximate == kTanh) {
const float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
auto beta = ScalarOp(kBeta, shape);
auto kappa = ScalarOp(0.044715, shape);
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,10 @@ NodePtr EluBackward(const Value& grad_output, const Value& output,
const at::Scalar& alpha, const at::Scalar& scale,
const at::Scalar& input_scale);

NodePtr Gelu(const Value& input, bool approximate);
NodePtr Gelu(const Value& input, xla::int64 approximate);

NodePtr GeluBackward(const Value& grad, const Value& input, bool approximate);
NodePtr GeluBackward(const Value& grad, const Value& input,
xla::int64 approximate);

NodePtr Lshift(const Value& input, const at::Scalar& other);

Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,9 @@ class XLATensor {

static XLATensor ge(const XLATensor& input, const XLATensor& other);

static XLATensor gelu(const XLATensor& input, bool approximate);
static XLATensor gelu(const XLATensor& input, xla::int64 approximate);
static XLATensor gelu_backward(const XLATensor& grad, const XLATensor& input,
bool approximate);
xla::int64 approximate);

static XLATensor ger(const XLATensor& input, const XLATensor& vec2);

Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,12 +1289,13 @@ XLATensor XLATensor::ge(const XLATensor& input, const XLATensor& other) {
return DispatchComparisonOp(at::aten::ge, input, other);
}

XLATensor XLATensor::gelu(const XLATensor& input, bool approximate) {
// Forwards gelu to the IR layer: builds a Gelu node over the input's IR
// value. `approximate` (0 = "none"/erf-based, 1 = "tanh" — see the kTanh
// check in ops.cpp) is passed through unchanged.
XLATensor XLATensor::gelu(const XLATensor& input, xla::int64 approximate) {
return input.CreateFrom(ir::ops::Gelu(input.GetIrValue(), approximate));
}

// Backward pass for gelu: builds a GeluBackward IR node from the incoming
// gradient and the saved forward input. `approximate` (0 = "none",
// 1 = "tanh") should match the value used in the forward pass — presumably
// enforced by the autograd caller; verify against the ATen dispatch.
XLATensor XLATensor::gelu_backward(const XLATensor& grad,
// NOTE(review): the `bool approximate` line below is pre-change diff
// residue left by the scraped commit view, not live code; the live
// signature continues with the `xla::int64 approximate` line.
const XLATensor& input, bool approximate) {
const XLATensor& input,
xla::int64 approximate) {
return input.CreateFrom(ir::ops::GeluBackward(
grad.GetIrValue(), input.GetIrValue(), approximate));
}
Expand Down

0 comments on commit e716d65

Please sign in to comment.